sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents a percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Whether COPY statement parameters are separated by comma or whitespace."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.

    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before
    column qualification (_qualify_columns()).

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" across the whole query (this
    is what _qualify_columns() does), except:
        - BigQuery, which forwards the alias to the GROUP BY and HAVING clauses, i.e. it
          resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - ClickHouse, which forwards the alias across the entire query, i.e. it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Optional[bool | str] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                        # Coerce the value to boolean if it matches the truthy/falsy values below
                        value_lower = value.lower()
                        if value_lower in ("true", "1"):
                            value = True
                        elif value_lower in ("false", "0"):
                            value = False

                    kwargs[key] = value

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system; for example, they may always be case-sensitive on Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format; `True` means the dialect's `TIME_FORMAT`.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = None
        if unit_based or default_unit:
            unit = args[0] if unit_based else exp.Literal.string(default_unit)
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and DuckDB, whose functions don't support charset and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and StarRocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
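The usage sketches below are editor additions, not part of the module source: short, hedged examples that exercise only the public names defined above (plus sqlglot.exp). First, dialect lookup through the registry that _Dialect.__new__ populates, including the comma-separated settings syntax handled by get_or_raise:

from sqlglot.dialects.dialect import Dialect, Dialects, NormalizationStrategy

# Subscripting goes through _Dialect.__getitem__ and returns the class itself.
duckdb_class = Dialect["duckdb"]

# _Dialect.__eq__ makes a dialect class comparable to its registry name.
assert duckdb_class == Dialects.DUCKDB.value

# get_or_raise returns an instance; trailing key-value settings are parsed,
# standalone keys default to True, and "true"/"false"/"1"/"0" become booleans.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE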
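Next, identifier normalization under the different NormalizationStrategy values, following the normalize_identifier docstring (Postgres lowercases unquoted identifiers, Snowflake uppercases them, and quoting blocks normalization under both strategies):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted
assert Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name == "foo"
assert Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name == "FOO"

# Quoted identifiers are treated as case-sensitive and left untouched.
quoted = exp.to_identifier("FoO", quoted=True)
assert Dialect.get_or_raise("postgres").normalize_identifier(quoted.copy()).name == "FoO"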
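The parse/generate/transpile helpers wire the dialect's tokenizer, parser and generator classes together; a minimal round trip:

from sqlglot.dialects.dialect import Dialect

duckdb = Dialect.get_or_raise("duckdb")

# parse returns one expression tree per statement; generate renders it back.
tree = duckdb.parse("SELECT 1 AS x")[0]
assert duckdb.generate(tree) == "SELECT 1 AS x"

# transpile is parse + generate over every statement in the input.
assert duckdb.transpile("SELECT 1; SELECT 2") == ["SELECT 1", "SELECT 2"]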
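format_time translates a dialect-specific format literal into its Python strftime equivalent via TIME_MAPPING and TIME_TRIE; a sketch using Hive, whose mapping covers these tokens:

from sqlglot.dialects.dialect import Dialect

hive = Dialect["hive"]
# String inputs are assumed to be quoted, hence the surrounding single quotes.
assert hive.format_time("'yyyy-MM-dd'").name == "%Y-%m-%d"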
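DATE_PART_MAPPING canonicalizes the many date-part spellings; map_date_part applies it to a parsed expression and passes unknown parts through unchanged:

from sqlglot import exp
from sqlglot.dialects.dialect import map_date_part

assert map_date_part(exp.var("dd")).name == "DAY"
assert map_date_part(exp.var("epoch_seconds")).name == "EPOCH"
assert map_date_part(exp.var("potato")).name == "potato"  # unmapped: returned as-is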
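Finally, unit_to_str and unit_to_var differ only in how they render a date-part unit (quoted string literal vs. bare var) and share the same DAY default; a sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import unit_to_str, unit_to_var

date_add = exp.DateAdd(
    this=exp.column("d"), expression=exp.Literal.number(1), unit=exp.var("MONTH")
)
assert unit_to_str(date_add).sql() == "'MONTH'"  # string literal form
assert unit_to_var(date_add).sql() == "MONTH"    # bare var form

no_unit = exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1))
assert unit_to_str(no_unit).sql() == "'DAY'"     # falls back to the default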
608 """ 609 if ( 610 isinstance(expression, exp.Identifier) 611 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 612 and ( 613 not expression.quoted 614 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 615 ) 616 ): 617 expression.set( 618 "this", 619 ( 620 expression.this.upper() 621 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 622 else expression.this.lower() 623 ), 624 ) 625 626 return expression 627 628 def case_sensitive(self, text: str) -> bool: 629 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 630 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 631 return False 632 633 unsafe = ( 634 str.islower 635 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 636 else str.isupper 637 ) 638 return any(unsafe(char) for char in text) 639 640 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 641 """Checks if text can be identified given an identify option. 642 643 Args: 644 text: The text to check. 645 identify: 646 `"always"` or `True`: Always returns `True`. 647 `"safe"`: Only returns `True` if the identifier is case-insensitive. 648 649 Returns: 650 Whether the given text can be identified. 651 """ 652 if identify is True or identify == "always": 653 return True 654 655 if identify == "safe": 656 return not self.case_sensitive(text) 657 658 return False 659 660 def quote_identifier(self, expression: E, identify: bool = True) -> E: 661 """ 662 Adds quotes to a given identifier. 663 664 Args: 665 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 666 identify: If set to `False`, the quotes will only be added if the identifier is deemed 667 "unsafe", with respect to its characters and this dialect's normalization strategy. 668 """ 669 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 670 name = expression.this 671 expression.set( 672 "quoted", 673 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 674 ) 675 676 return expression 677 678 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 679 if isinstance(path, exp.Literal): 680 path_text = path.name 681 if path.is_number: 682 path_text = f"[{path_text}]" 683 try: 684 return parse_json_path(path_text, self) 685 except ParseError as e: 686 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 687 688 return path 689 690 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 691 return self.parser(**opts).parse(self.tokenize(sql), sql) 692 693 def parse_into( 694 self, expression_type: exp.IntoType, sql: str, **opts 695 ) -> t.List[t.Optional[exp.Expression]]: 696 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 697 698 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 699 return self.generator(**opts).generate(expression, copy=copy) 700 701 def transpile(self, sql: str, **opts) -> t.List[str]: 702 return [ 703 self.generate(expression, copy=False, **opts) if expression else "" 704 for expression in self.parse(sql) 705 ] 706 707 def tokenize(self, sql: str) -> t.List[Token]: 708 return self.tokenizer.tokenize(sql) 709 710 @property 711 def tokenizer(self) -> Tokenizer: 712 return self.tokenizer_class(dialect=self) 713 714 @property 715 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 716 return self.jsonpath_tokenizer_class(dialect=self) 717 718 def parser(self, **opts) -> Parser: 719 return self.parser_class(dialect=self, **opts) 720 721 def generator(self, **opts) -> Generator: 722 return self.generator_class(dialect=self, **opts)
572 def __init__(self, **kwargs) -> None: 573 normalization_strategy = kwargs.pop("normalization_strategy", None) 574 575 if normalization_strategy is None: 576 self.normalization_strategy = self.NORMALIZATION_STRATEGY 577 else: 578 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 579 580 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents a percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
- "upper" or True: Convert names to uppercase.
- "lower": Convert names to lowercase.
- False: Disables function name normalization.
Whether the base comes first in the `LOG` function.
Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
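A hedged transpilation sketch: assuming Postgres treats the first `LOG` argument as the base while BigQuery expects the base last, sqlglot swaps the arguments when converting between the two (exact output may vary by version):

>>> import sqlglot
>>> sqlglot.transpile("SELECT LOG(10, 100)", read="postgres", write="bigquery")[0]
'SELECT LOG(100, 10)'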
Default `NULL` ordering method to use if not explicitly set.
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
Whether the behavior of `a / b` depends on the types of `a` and `b`.
False means `a / b` is always float division.
True means `a / b` is integer division if both `a` and `b` are integers.
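A minimal Python sketch (illustrative semantics only, not sqlglot API) of `TYPED_DIVISION` together with the related `SAFE_DIVISION` flag documented in the class source above:

def div(a, b, typed=False, safe=False):
    """Illustrative semantics of TYPED_DIVISION and SAFE_DIVISION."""
    if b == 0:
        if safe:
            return None  # SAFE_DIVISION=True: division by zero yields NULL
        raise ZeroDivisionError("division by zero")  # SAFE_DIVISION=False: an error
    if typed and isinstance(a, int) and isinstance(b, int):
        return a // b  # TYPED_DIVISION=True: int / int is integer division
    return a / b  # otherwise a / b is always float division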
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` formats.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,
WITH y(c) AS (
SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS (
SELECT
1 AS id,
2 AS my_id
)
SELECT
id AS my_id
FROM
data
WHERE
my_id = 1
GROUP BY
my_id
HAVING
my_id = 1
In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
- BigQuery, which forwards the alias to the GROUP BY and HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- ClickHouse, which forwards the alias across the whole query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
488 @classmethod 489 def get_or_raise(cls, dialect: DialectType) -> Dialect: 490 """ 491 Look up a dialect in the global dialect registry and return it if it exists. 492 493 Args: 494 dialect: The target dialect. If this is a string, it can be optionally followed by 495 additional key-value pairs that are separated by commas and are used to specify 496 dialect settings, such as whether the dialect's identifiers are case-sensitive. 497 498 Example: 499 >>> dialect = dialect_class = get_or_raise("duckdb") 500 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 501 502 Returns: 503 The corresponding Dialect instance. 504 """ 505 506 if not dialect: 507 return cls() 508 if isinstance(dialect, _Dialect): 509 return dialect() 510 if isinstance(dialect, Dialect): 511 return dialect 512 if isinstance(dialect, str): 513 try: 514 dialect_name, *kv_strings = dialect.split(",") 515 kv_pairs = (kv.split("=") for kv in kv_strings) 516 kwargs = {} 517 for pair in kv_pairs: 518 key = pair[0].strip() 519 value: t.Union[bool | str | None] = None 520 521 if len(pair) == 1: 522 # Default initialize standalone settings to True 523 value = True 524 elif len(pair) == 2: 525 value = pair[1].strip() 526 527 # Coerce the value to boolean if it matches to the truthy/falsy values below 528 value_lower = value.lower() 529 if value_lower in ("true", "1"): 530 value = True 531 elif value_lower in ("false", "0"): 532 value = False 533 534 kwargs[key] = value 535 536 except ValueError: 537 raise ValueError( 538 f"Invalid dialect format: '{dialect}'. " 539 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 540 ) 541 542 result = cls.get(dialect_name.strip()) 543 if not result: 544 from difflib import get_close_matches 545 546 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 547 if similar: 548 similar = f" Did you mean {similar}?" 549 550 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 551 552 return result(**kwargs) 553 554 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
556 @classmethod 557 def format_time( 558 cls, expression: t.Optional[str | exp.Expression] 559 ) -> t.Optional[exp.Expression]: 560 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 561 if isinstance(expression, str): 562 return exp.Literal.string( 563 # the time formats are quoted 564 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 565 ) 566 567 if expression and expression.is_string: 568 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 569 570 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
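For instance, assuming Hive's `TIME_MAPPING` covers these tokens:

>>> from sqlglot.dialects.hive import Hive
>>> Hive.format_time("'yyyy-MM-dd'").this
'%Y-%m-%d'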
590 def normalize_identifier(self, expression: E) -> E: 591 """ 592 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 593 594 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 595 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 596 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 597 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 598 599 There are also dialects like Spark, which are case-insensitive even when quotes are 600 present, and dialects like MySQL, whose resolution rules match those employed by the 601 underlying operating system, for example they may always be case-sensitive in Linux. 602 603 Finally, the normalization behavior of some engines can even be controlled through flags, 604 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 605 606 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 607 that it can analyze queries in the optimizer and successfully capture their semantics. 608 """ 609 if ( 610 isinstance(expression, exp.Identifier) 611 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 612 and ( 613 not expression.quoted 614 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 615 ) 616 ): 617 expression.set( 618 "this", 619 ( 620 expression.this.upper() 621 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 622 else expression.this.lower() 623 ), 624 ) 625 626 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
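A small illustration of the strategies described above:

>>> from sqlglot import exp
>>> from sqlglot.dialects.postgres import Postgres
>>> from sqlglot.dialects.snowflake import Snowflake
>>> Postgres().normalize_identifier(exp.to_identifier("FoO")).this
'foo'
>>> Snowflake().normalize_identifier(exp.to_identifier("FoO")).this
'FOO'
>>> Postgres().normalize_identifier(exp.to_identifier("FoO", quoted=True)).this
'FoO'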
628 def case_sensitive(self, text: str) -> bool: 629 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 630 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 631 return False 632 633 unsafe = ( 634 str.islower 635 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 636 else str.isupper 637 ) 638 return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
640 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 641 """Checks if text can be identified given an identify option. 642 643 Args: 644 text: The text to check. 645 identify: 646 `"always"` or `True`: Always returns `True`. 647 `"safe"`: Only returns `True` if the identifier is case-insensitive. 648 649 Returns: 650 Whether the given text can be identified. 651 """ 652 if identify is True or identify == "always": 653 return True 654 655 if identify == "safe": 656 return not self.case_sensitive(text) 657 658 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
`"always"` or `True`: Always returns `True`.
`"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
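Under the default lowercase normalization strategy, for example:

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.can_identify("foo", "safe")
True
>>> d.can_identify("Foo", "safe")
False
>>> d.can_identify("Foo", "always")
True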
660 def quote_identifier(self, expression: E, identify: bool = True) -> E: 661 """ 662 Adds quotes to a given identifier. 663 664 Args: 665 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 666 identify: If set to `False`, the quotes will only be added if the identifier is deemed 667 "unsafe", with respect to its characters and this dialect's normalization strategy. 668 """ 669 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 670 name = expression.this 671 expression.set( 672 "quoted", 673 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 674 ) 675 676 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
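For example, with the default dialect (lowercase strategy, `"` as the identifier delimiter):

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect().quote_identifier(exp.to_identifier("foo"), identify=False).sql()
'foo'
>>> Dialect().quote_identifier(exp.to_identifier("Foo"), identify=False).sql()
'"Foo"'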
678 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 679 if isinstance(path, exp.Literal): 680 path_text = path.name 681 if path.is_number: 682 path_text = f"[{path_text}]" 683 try: 684 return parse_json_path(path_text, self) 685 except ParseError as e: 686 logger.warning(f"Invalid JSON path syntax. {str(e)}") 687 688 return path
738def if_sql( 739 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 740) -> t.Callable[[Generator, exp.If], str]: 741 def _if_sql(self: Generator, expression: exp.If) -> str: 742 return self.func( 743 name, 744 expression.this, 745 expression.args.get("true"), 746 expression.args.get("false") or false_value, 747 ) 748 749 return _if_sql
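`if_sql` is a transform factory: dialects plug its return value into their generator's `TRANSFORMS` map. A hypothetical wiring (the `IIF` target name and `NULL` fallback are illustrative, not taken from any particular dialect):

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.dialects.dialect import if_sql

class MyGenerator(Generator):
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # Render exp.If as IIF(cond, true, false), defaulting a missing
        # false branch to NULL
        exp.If: if_sql(name="IIF", false_value="NULL"),
    }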
752def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 753 this = expression.this 754 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 755 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 756 757 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
823def str_position_sql( 824 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 825) -> str: 826 this = self.sql(expression, "this") 827 substr = self.sql(expression, "substr") 828 position = self.sql(expression, "position") 829 instance = expression.args.get("instance") if generate_instance else None 830 position_offset = "" 831 832 if position: 833 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 834 this = self.func("SUBSTR", this, position) 835 position_offset = f" + {position} - 1" 836 837 return self.func("STRPOS", this, substr, instance) + position_offset
846def var_map_sql( 847 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 848) -> str: 849 keys = expression.args["keys"] 850 values = expression.args["values"] 851 852 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 853 self.unsupported("Cannot convert array columns into map.") 854 return self.func(map_func_name, keys, values) 855 856 args = [] 857 for key, value in zip(keys.expressions, values.expressions): 858 args.append(self.sql(key)) 859 args.append(self.sql(value)) 860 861 return self.func(map_func_name, *args)
864def build_formatted_time( 865 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 866) -> t.Callable[[t.List], E]: 867 """Helper used for time expressions. 868 869 Args: 870 exp_class: the expression class to instantiate. 871 dialect: target sql dialect. 872 default: the default format, True being time. 873 874 Returns: 875 A callable that can be used to return the appropriately formatted time expression. 876 """ 877 878 def _builder(args: t.List): 879 return exp_class( 880 this=seq_get(args, 0), 881 format=Dialect[dialect].format_time( 882 seq_get(args, 1) 883 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 884 ), 885 ) 886 887 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target SQL dialect.
- default: the default format; `True` means the dialect's `TIME_FORMAT` is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
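A hedged usage sketch, mirroring how dialect modules register parsers for format-taking functions (the choice of `exp.StrToDate` here is illustrative):

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import build_formatted_time
>>> builder = build_formatted_time(exp.StrToDate, "hive")
>>> node = builder([exp.column("col"), exp.Literal.string("yyyy-MM-dd")])
>>> node.args["format"].this  # the Hive format was converted to strftime
'%Y-%m-%d'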
890def time_format( 891 dialect: DialectType = None, 892) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 893 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 894 """ 895 Returns the time format for a given expression, unless it's equivalent 896 to the default time format of the dialect of interest. 897 """ 898 time_format = self.format_time(expression) 899 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 900 901 return _time_format
904def build_date_delta( 905 exp_class: t.Type[E], 906 unit_mapping: t.Optional[t.Dict[str, str]] = None, 907 default_unit: t.Optional[str] = "DAY", 908) -> t.Callable[[t.List], E]: 909 def _builder(args: t.List) -> E: 910 unit_based = len(args) == 3 911 this = args[2] if unit_based else seq_get(args, 0) 912 unit = None 913 if unit_based or default_unit: 914 unit = args[0] if unit_based else exp.Literal.string(default_unit) 915 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 916 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 917 918 return _builder
921def build_date_delta_with_interval( 922 expression_class: t.Type[E], 923) -> t.Callable[[t.List], t.Optional[E]]: 924 def _builder(args: t.List) -> t.Optional[E]: 925 if len(args) < 2: 926 return None 927 928 interval = args[1] 929 930 if not isinstance(interval, exp.Interval): 931 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 932 933 expression = interval.this 934 if expression and expression.is_string: 935 expression = exp.Literal.number(expression.this) 936 937 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 938 939 return _builder
951def date_add_interval_sql( 952 data_type: str, kind: str 953) -> t.Callable[[Generator, exp.Expression], str]: 954 def func(self: Generator, expression: exp.Expression) -> str: 955 this = self.sql(expression, "this") 956 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 957 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 958 959 return func
962def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 963 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 964 args = [unit_to_str(expression), expression.this] 965 if zone: 966 args.append(expression.args.get("zone")) 967 return self.func("DATE_TRUNC", *args) 968 969 return _timestamptrunc_sql
972def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 973 zone = expression.args.get("zone") 974 if not zone: 975 from sqlglot.optimizer.annotate_types import annotate_types 976 977 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 978 return self.sql(exp.cast(expression.this, target_type)) 979 if zone.name.lower() in TIMEZONES: 980 return self.sql( 981 exp.AtTimeZone( 982 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 983 zone=zone, 984 ) 985 ) 986 return self.func("TIMESTAMP", expression.this, zone)
989def no_time_sql(self: Generator, expression: exp.Time) -> str: 990 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 991 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 992 expr = exp.cast( 993 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 994 ) 995 return self.sql(expr)
998def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 999 this = expression.this 1000 expr = expression.expression 1001 1002 if expr.name.lower() in TIMEZONES: 1003 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1004 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1005 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1006 return self.sql(this) 1007 1008 this = exp.cast(this, exp.DataType.Type.DATE) 1009 expr = exp.cast(expr, exp.DataType.Type.TIME) 1010 1011 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1052def encode_decode_sql( 1053 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1054) -> str: 1055 charset = expression.args.get("charset") 1056 if charset and charset.name.lower() != "utf-8": 1057 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1058 1059 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1072def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1073 cond = expression.this 1074 1075 if isinstance(expression.this, exp.Distinct): 1076 cond = expression.this.expressions[0] 1077 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1078 1079 return self.func("sum", exp.func("if", cond, 1, 0))
1082def trim_sql(self: Generator, expression: exp.Trim) -> str: 1083 target = self.sql(expression, "this") 1084 trim_type = self.sql(expression, "position") 1085 remove_chars = self.sql(expression, "expression") 1086 collation = self.sql(expression, "collation") 1087 1088 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1089 if not remove_chars and not collation: 1090 return self.trim_sql(expression) 1091 1092 trim_type = f"{trim_type} " if trim_type else "" 1093 remove_chars = f"{remove_chars} " if remove_chars else "" 1094 from_part = "FROM " if trim_type or remove_chars else "" 1095 collation = f" COLLATE {collation}" if collation else "" 1096 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1117def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1118 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1119 if bad_args: 1120 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1121 1122 return self.func( 1123 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1124 )
1127def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1128 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1129 if bad_args: 1130 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1131 1132 return self.func( 1133 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1134 )
1137def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1138 names = [] 1139 for agg in aggregations: 1140 if isinstance(agg, exp.Alias): 1141 names.append(agg.alias) 1142 else: 1143 """ 1144 This case corresponds to aggregations without aliases being used as suffixes 1145 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1146 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1147 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1148 """ 1149 agg_all_unquoted = agg.transform( 1150 lambda node: ( 1151 exp.Identifier(this=node.name, quoted=False) 1152 if isinstance(node, exp.Identifier) 1153 else node 1154 ) 1155 ) 1156 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1157 1158 return names
1198def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1199 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1200 if expression.args.get("count"): 1201 self.unsupported(f"Only two arguments are supported in function {name}.") 1202 1203 return self.func(name, expression.this, expression.expression) 1204 1205 return _arg_max_or_min_sql
1208def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1209 this = expression.this.copy() 1210 1211 return_type = expression.return_type 1212 if return_type.is_type(exp.DataType.Type.DATE): 1213 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1214 # can truncate timestamp strings, because some dialects can't cast them to DATE 1215 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1216 1217 expression.this.replace(exp.cast(this, return_type)) 1218 return expression
1221def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1222 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1223 if cast and isinstance(expression, exp.TsOrDsAdd): 1224 expression = ts_or_ds_add_cast(expression) 1225 1226 return self.func( 1227 name, 1228 unit_to_var(expression), 1229 expression.expression, 1230 expression.this, 1231 ) 1232 1233 return _delta_sql
1236def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1237 unit = expression.args.get("unit") 1238 1239 if isinstance(unit, exp.Placeholder): 1240 return unit 1241 if unit: 1242 return exp.Literal.string(unit.name) 1243 return exp.Literal.string(default) if default else None
1273def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1274 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1275 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1276 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1277 1278 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1281def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1282 """Remove table refs from columns in when statements.""" 1283 alias = expression.this.args.get("alias") 1284 1285 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1286 return self.dialect.normalize_identifier(identifier).name if identifier else None 1287 1288 targets = {normalize(expression.this.this)} 1289 1290 if alias: 1291 targets.add(normalize(alias.this)) 1292 1293 for when in expression.expressions: 1294 when.transform( 1295 lambda node: ( 1296 exp.column(node.this) 1297 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1298 else node 1299 ), 1300 copy=False, 1301 ) 1302 1303 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1306def build_json_extract_path( 1307 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1308) -> t.Callable[[t.List], F]: 1309 def _builder(args: t.List) -> F: 1310 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1311 for arg in args[1:]: 1312 if not isinstance(arg, exp.Literal): 1313 # We use the fallback parser because we can't really transpile non-literals safely 1314 return expr_type.from_arg_list(args) 1315 1316 text = arg.name 1317 if is_int(text): 1318 index = int(text) 1319 segments.append( 1320 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1321 ) 1322 else: 1323 segments.append(exp.JSONPathKey(this=text)) 1324 1325 # This is done to avoid failing in the expression validator due to the arg count 1326 del args[2:] 1327 return expr_type( 1328 this=seq_get(args, 0), 1329 expression=exp.JSONPath(expressions=segments), 1330 only_json_types=arrow_req_json_type, 1331 ) 1332 1333 return _builder
1336def json_extract_segments( 1337 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1338) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1339 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1340 path = expression.expression 1341 if not isinstance(path, exp.JSONPath): 1342 return rename_func(name)(self, expression) 1343 1344 segments = [] 1345 for segment in path.expressions: 1346 path = self.sql(segment) 1347 if path: 1348 if isinstance(segment, exp.JSONPathPart) and ( 1349 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1350 ): 1351 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1352 1353 segments.append(path) 1354 1355 if op: 1356 return f" {op} ".join([self.sql(expression.this), *segments]) 1357 return self.func(name, expression.this, *segments) 1358 1359 return _json_extract_segments
1369def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1370 cond = expression.expression 1371 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1372 alias = cond.expressions[0] 1373 cond = cond.this 1374 elif isinstance(cond, exp.Predicate): 1375 alias = "_u" 1376 else: 1377 self.unsupported("Unsupported filter condition") 1378 return "" 1379 1380 unnest = exp.Unnest(expressions=[expression.this]) 1381 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1382 return self.sql(exp.Array(expressions=[filtered]))
1394def build_default_decimal_type( 1395 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1396) -> t.Callable[[exp.DataType], exp.DataType]: 1397 def _builder(dtype: exp.DataType) -> exp.DataType: 1398 if dtype.expressions or precision is None: 1399 return dtype 1400 1401 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1402 return exp.DataType.build(f"DECIMAL({params})") 1403 1404 return _builder
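For instance, a dialect whose bare `DECIMAL` should default to `DECIMAL(38, 9)` (the precision and scale here are illustrative) could do:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import build_default_decimal_type
>>> builder = build_default_decimal_type(precision=38, scale=9)
>>> builder(exp.DataType.build("DECIMAL")).sql()
'DECIMAL(38, 9)'
>>> builder(exp.DataType.build("DECIMAL(10, 2)")).sql()  # explicit params are kept
'DECIMAL(10, 2)'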
1407def build_timestamp_from_parts(args: t.List) -> exp.Func: 1408 if len(args) == 2: 1409 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1410 # so we parse this into Anonymous for now instead of introducing complexity 1411 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1412 1413 return exp.TimestampFromParts.from_arg_list(args)