sqlglot.optimizer.qualify_columns
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import alias, exp 7from sqlglot.dialects.dialect import Dialect, DialectType 8from sqlglot.errors import OptimizeError 9from sqlglot.helper import seq_get, SingleValuedMapping 10from sqlglot.optimizer.annotate_types import TypeAnnotator 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope 12from sqlglot.optimizer.simplify import simplify_parens 13from sqlglot.schema import Schema, ensure_schema 14 15if t.TYPE_CHECKING: 16 from sqlglot._typing import E 17 18 19def qualify_columns( 20 expression: exp.Expression, 21 schema: t.Dict | Schema, 22 expand_alias_refs: bool = True, 23 expand_stars: bool = True, 24 infer_schema: t.Optional[bool] = None, 25) -> exp.Expression: 26 """ 27 Rewrite sqlglot AST to have fully qualified columns. 28 29 Example: 30 >>> import sqlglot 31 >>> schema = {"tbl": {"col": "INT"}} 32 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 33 >>> qualify_columns(expression, schema).sql() 34 'SELECT tbl.col AS col FROM tbl' 35 36 Args: 37 expression: Expression to qualify. 38 schema: Database schema. 39 expand_alias_refs: Whether to expand references to aliases. 40 expand_stars: Whether to expand star queries. This is a necessary step 41 for most of the optimizer's rules to work; do not set to False unless you 42 know what you're doing! 43 infer_schema: Whether to infer the schema if missing. 44 45 Returns: 46 The qualified expression. 
47 48 Notes: 49 - Currently only handles a single PIVOT or UNPIVOT operator 50 """ 51 schema = ensure_schema(schema) 52 annotator = TypeAnnotator(schema) 53 infer_schema = schema.empty if infer_schema is None else infer_schema 54 dialect = Dialect.get_or_raise(schema.dialect) 55 pseudocolumns = dialect.PSEUDOCOLUMNS 56 57 for scope in traverse_scope(expression): 58 resolver = Resolver(scope, schema, infer_schema=infer_schema) 59 _pop_table_column_aliases(scope.ctes) 60 _pop_table_column_aliases(scope.derived_tables) 61 using_column_tables = _expand_using(scope, resolver) 62 63 if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs: 64 _expand_alias_refs( 65 scope, 66 resolver, 67 expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY, 68 ) 69 70 _convert_columns_to_dots(scope, resolver) 71 _qualify_columns(scope, resolver) 72 73 if not schema.empty and expand_alias_refs: 74 _expand_alias_refs(scope, resolver) 75 76 if not isinstance(scope.expression, exp.UDTF): 77 if expand_stars: 78 _expand_stars( 79 scope, 80 resolver, 81 using_column_tables, 82 pseudocolumns, 83 annotator, 84 ) 85 qualify_outputs(scope) 86 87 _expand_group_by(scope, dialect) 88 _expand_order_by(scope, resolver) 89 90 if dialect == "bigquery": 91 annotator.annotate_scope(scope) 92 93 return expression 94 95 96def validate_qualify_columns(expression: E) -> E: 97 """Raise an `OptimizeError` if any columns aren't qualified""" 98 all_unqualified_columns = [] 99 for scope in traverse_scope(expression): 100 if isinstance(scope.expression, exp.Select): 101 unqualified_columns = scope.unqualified_columns 102 103 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 104 column = scope.external_columns[0] 105 for_table = f" for table: '{column.table}'" if column.table else "" 106 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 107 108 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 109 # 
New columns produced by the UNPIVOT can't be qualified, but there may be columns 110 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 111 # this list here to ensure those in the former category will be excluded. 112 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 113 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 114 115 all_unqualified_columns.extend(unqualified_columns) 116 117 if all_unqualified_columns: 118 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 119 120 return expression 121 122 123def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]: 124 name_column = [] 125 field = unpivot.args.get("field") 126 if isinstance(field, exp.In) and isinstance(field.this, exp.Column): 127 name_column.append(field.this) 128 129 value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column)) 130 return itertools.chain(name_column, value_columns) 131 132 133def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: 134 """ 135 Remove table column aliases. 136 137 For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2) 138 """ 139 for derived_table in derived_tables: 140 if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive: 141 continue 142 table_alias = derived_table.args.get("alias") 143 if table_alias: 144 table_alias.args.pop("columns", None) 145 146 147def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: 148 joins = list(scope.find_all(exp.Join)) 149 names = {join.alias_or_name for join in joins} 150 ordered = [key for key in scope.selected_sources if key not in names] 151 152 # Mapping of automatically joined column names to an ordered set of source names (dict). 
153 column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} 154 155 for i, join in enumerate(joins): 156 using = join.args.get("using") 157 158 if not using: 159 continue 160 161 join_table = join.alias_or_name 162 163 columns = {} 164 165 for source_name in scope.selected_sources: 166 if source_name in ordered: 167 for column_name in resolver.get_source_columns(source_name): 168 if column_name not in columns: 169 columns[column_name] = source_name 170 171 source_table = ordered[-1] 172 ordered.append(join_table) 173 join_columns = resolver.get_source_columns(join_table) 174 conditions = [] 175 using_identifier_count = len(using) 176 177 for identifier in using: 178 identifier = identifier.name 179 table = columns.get(identifier) 180 181 if not table or identifier not in join_columns: 182 if (columns and "*" not in columns) and join_columns: 183 raise OptimizeError(f"Cannot automatically join: {identifier}") 184 185 table = table or source_table 186 187 if i == 0 or using_identifier_count == 1: 188 lhs: exp.Expression = exp.column(identifier, table=table) 189 else: 190 coalesce_columns = [ 191 exp.column(identifier, table=t) 192 for t in ordered[:-1] 193 if identifier in resolver.get_source_columns(t) 194 ] 195 if len(coalesce_columns) > 1: 196 lhs = exp.func("coalesce", *coalesce_columns) 197 else: 198 lhs = exp.column(identifier, table=table) 199 200 conditions.append(lhs.eq(exp.column(identifier, table=join_table))) 201 202 # Set all values in the dict to None, because we only care about the key ordering 203 tables = column_tables.setdefault(identifier, {}) 204 if table not in tables: 205 tables[table] = None 206 if join_table not in tables: 207 tables[join_table] = None 208 209 join.args.pop("using") 210 join.set("on", exp.and_(*conditions, copy=False)) 211 212 if column_tables: 213 for column in scope.columns: 214 if not column.table and column.name in column_tables: 215 tables = column_tables[column.name] 216 coalesce_args = [exp.column(column.name, table=table) 
for table in tables] 217 replacement = exp.func("coalesce", *coalesce_args) 218 219 # Ensure selects keep their output name 220 if isinstance(column.parent, exp.Select): 221 replacement = alias(replacement, alias=column.name, copy=False) 222 223 scope.replace(column, replacement) 224 225 return column_tables 226 227 228def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None: 229 expression = scope.expression 230 231 if not isinstance(expression, exp.Select): 232 return 233 234 alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} 235 236 def replace_columns( 237 node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False 238 ) -> None: 239 if not node or (expand_only_groupby and not isinstance(node, exp.Group)): 240 return 241 242 for column in walk_in_scope(node, prune=lambda node: node.is_star): 243 if not isinstance(column, exp.Column): 244 continue 245 246 table = resolver.get_table(column.name) if resolve_table and not column.table else None 247 alias_expr, i = alias_to_expression.get(column.name, (None, 1)) 248 double_agg = ( 249 ( 250 alias_expr.find(exp.AggFunc) 251 and ( 252 column.find_ancestor(exp.AggFunc) 253 and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window) 254 ) 255 ) 256 if alias_expr 257 else False 258 ) 259 260 if table and (not alias_expr or double_agg): 261 column.set("table", table) 262 elif not column.table and alias_expr and not double_agg: 263 if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table): 264 if literal_index: 265 column.replace(exp.Literal.number(i)) 266 else: 267 column = column.replace(exp.paren(alias_expr)) 268 simplified = simplify_parens(column) 269 if simplified is not column: 270 column.replace(simplified) 271 272 for i, projection in enumerate(scope.expression.selects): 273 replace_columns(projection) 274 275 if isinstance(projection, exp.Alias): 276 alias_to_expression[projection.alias] = 
(projection.this, i + 1) 277 278 replace_columns(expression.args.get("where")) 279 replace_columns(expression.args.get("group"), literal_index=True) 280 replace_columns(expression.args.get("having"), resolve_table=True) 281 replace_columns(expression.args.get("qualify"), resolve_table=True) 282 283 scope.clear_cache() 284 285 286def _expand_group_by(scope: Scope, dialect: DialectType) -> None: 287 expression = scope.expression 288 group = expression.args.get("group") 289 if not group: 290 return 291 292 group.set("expressions", _expand_positional_references(scope, group.expressions, dialect)) 293 expression.set("group", group) 294 295 296def _expand_order_by(scope: Scope, resolver: Resolver) -> None: 297 order = scope.expression.args.get("order") 298 if not order: 299 return 300 301 ordereds = order.expressions 302 for ordered, new_expression in zip( 303 ordereds, 304 _expand_positional_references( 305 scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True 306 ), 307 ): 308 for agg in ordered.find_all(exp.AggFunc): 309 for col in agg.find_all(exp.Column): 310 if not col.table: 311 col.set("table", resolver.get_table(col.name)) 312 313 ordered.set("this", new_expression) 314 315 if scope.expression.args.get("group"): 316 selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects} 317 318 for ordered in ordereds: 319 ordered = ordered.this 320 321 ordered.replace( 322 exp.to_identifier(_select_by_pos(scope, ordered).alias) 323 if ordered.is_int 324 else selects.get(ordered, ordered) 325 ) 326 327 328def _expand_positional_references( 329 scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False 330) -> t.List[exp.Expression]: 331 new_nodes: t.List[exp.Expression] = [] 332 ambiguous_projections = None 333 334 for node in expressions: 335 if node.is_int: 336 select = _select_by_pos(scope, t.cast(exp.Literal, node)) 337 338 if alias: 339 new_nodes.append(exp.column(select.args["alias"].copy())) 
340 else: 341 select = select.this 342 343 if dialect == "bigquery": 344 if ambiguous_projections is None: 345 # When a projection name is also a source name and it is referenced in the 346 # GROUP BY clause, BQ can't understand what the identifier corresponds to 347 ambiguous_projections = { 348 s.alias_or_name 349 for s in scope.expression.selects 350 if s.alias_or_name in scope.selected_sources 351 } 352 353 ambiguous = any( 354 column.parts[0].name in ambiguous_projections 355 for column in select.find_all(exp.Column) 356 ) 357 else: 358 ambiguous = False 359 360 if ( 361 isinstance(select, exp.CONSTANTS) 362 or select.find(exp.Explode, exp.Unnest) 363 or ambiguous 364 ): 365 new_nodes.append(node) 366 else: 367 new_nodes.append(select.copy()) 368 else: 369 new_nodes.append(node) 370 371 return new_nodes 372 373 374def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: 375 try: 376 return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias) 377 except IndexError: 378 raise OptimizeError(f"Unknown output column: {node.name}") 379 380 381def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None: 382 """ 383 Converts `Column` instances that represent struct field lookup into chained `Dots`. 384 385 Struct field lookups look like columns (e.g. "struct"."field"), but they need to be 386 qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)). 
387 """ 388 converted = False 389 for column in itertools.chain(scope.columns, scope.stars): 390 if isinstance(column, exp.Dot): 391 continue 392 393 column_table: t.Optional[str | exp.Identifier] = column.table 394 if ( 395 column_table 396 and column_table not in scope.sources 397 and ( 398 not scope.parent 399 or column_table not in scope.parent.sources 400 or not scope.is_correlated_subquery 401 ) 402 ): 403 root, *parts = column.parts 404 405 if root.name in scope.sources: 406 # The struct is already qualified, but we still need to change the AST 407 column_table = root 408 root, *parts = parts 409 else: 410 column_table = resolver.get_table(root.name) 411 412 if column_table: 413 converted = True 414 column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) 415 416 if converted: 417 # We want to re-aggregate the converted columns, otherwise they'd be skipped in 418 # a `for column in scope.columns` iteration, even though they shouldn't be 419 scope.clear_cache() 420 421 422def _qualify_columns(scope: Scope, resolver: Resolver) -> None: 423 """Disambiguate columns, ensuring each column specifies a source""" 424 for column in scope.columns: 425 column_table = column.table 426 column_name = column.name 427 428 if column_table and column_table in scope.sources: 429 source_columns = resolver.get_source_columns(column_table) 430 if source_columns and column_name not in source_columns and "*" not in source_columns: 431 raise OptimizeError(f"Unknown column: {column_name}") 432 433 if not column_table: 434 if scope.pivots and not column.find_ancestor(exp.Pivot): 435 # If the column is under the Pivot expression, we need to qualify it 436 # using the name of the pivoted source instead of the pivot's alias 437 column.set("table", exp.to_identifier(scope.pivots[0].alias)) 438 continue 439 440 # column_table can be a '' because bigquery unnest has no table alias 441 column_table = resolver.get_table(column_name) 442 if column_table: 443 
column.set("table", column_table) 444 445 for pivot in scope.pivots: 446 for column in pivot.find_all(exp.Column): 447 if not column.table and column.name in resolver.all_columns: 448 column_table = resolver.get_table(column.name) 449 if column_table: 450 column.set("table", column_table) 451 452 453def _expand_struct_stars( 454 expression: exp.Dot, 455) -> t.List[exp.Alias]: 456 """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column""" 457 458 dot_column = t.cast(exp.Column, expression.find(exp.Column)) 459 if not dot_column.is_type(exp.DataType.Type.STRUCT): 460 return [] 461 462 # All nested struct values are ColumnDefs, so normalize the first exp.Column in one 463 dot_column = dot_column.copy() 464 starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type) 465 466 # First part is the table name and last part is the star so they can be dropped 467 dot_parts = expression.parts[1:-1] 468 469 # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case) 470 for part in dot_parts[1:]: 471 for field in t.cast(exp.DataType, starting_struct.kind).expressions: 472 # Unable to expand star unless all fields are named 473 if not isinstance(field.this, exp.Identifier): 474 return [] 475 476 if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT): 477 starting_struct = field 478 break 479 else: 480 # There is no matching field in the struct 481 return [] 482 483 taken_names = set() 484 new_selections = [] 485 486 for field in t.cast(exp.DataType, starting_struct.kind).expressions: 487 name = field.name 488 489 # Ambiguous or anonymous fields can't be expanded 490 if name in taken_names or not isinstance(field.this, exp.Identifier): 491 return [] 492 493 taken_names.add(name) 494 495 this = field.this.copy() 496 root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])] 497 new_column = exp.column( 498 t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts 
499 ) 500 new_selections.append(alias(new_column, this, copy=False)) 501 502 return new_selections 503 504 505def _expand_stars( 506 scope: Scope, 507 resolver: Resolver, 508 using_column_tables: t.Dict[str, t.Any], 509 pseudocolumns: t.Set[str], 510 annotator: TypeAnnotator, 511) -> None: 512 """Expand stars to lists of column selections""" 513 514 new_selections = [] 515 except_columns: t.Dict[int, t.Set[str]] = {} 516 replace_columns: t.Dict[int, t.Dict[str, str]] = {} 517 coalesced_columns = set() 518 dialect = resolver.schema.dialect 519 520 pivot_output_columns = None 521 pivot_exclude_columns = None 522 523 pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) 524 if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names: 525 if pivot.unpivot: 526 pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)] 527 528 field = pivot.args.get("field") 529 if isinstance(field, exp.In): 530 pivot_exclude_columns = { 531 c.output_name for e in field.expressions for c in e.find_all(exp.Column) 532 } 533 else: 534 pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column)) 535 536 pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])] 537 if not pivot_output_columns: 538 pivot_output_columns = [c.alias_or_name for c in pivot.expressions] 539 540 is_bigquery = dialect == "bigquery" 541 if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars): 542 # Found struct expansion, annotate scope ahead of time 543 annotator.annotate_scope(scope) 544 545 for expression in scope.expression.selects: 546 tables = [] 547 if isinstance(expression, exp.Star): 548 tables.extend(scope.selected_sources) 549 _add_except_columns(expression, tables, except_columns) 550 _add_replace_columns(expression, tables, replace_columns) 551 elif expression.is_star: 552 if not isinstance(expression, exp.Dot): 553 tables.append(expression.table) 554 _add_except_columns(expression.this, tables, except_columns) 555 
_add_replace_columns(expression.this, tables, replace_columns) 556 elif is_bigquery: 557 struct_fields = _expand_struct_stars(expression) 558 if struct_fields: 559 new_selections.extend(struct_fields) 560 continue 561 562 if not tables: 563 new_selections.append(expression) 564 continue 565 566 for table in tables: 567 if table not in scope.sources: 568 raise OptimizeError(f"Unknown table: {table}") 569 570 columns = resolver.get_source_columns(table, only_visible=True) 571 columns = columns or scope.outer_columns 572 573 if pseudocolumns: 574 columns = [name for name in columns if name.upper() not in pseudocolumns] 575 576 if not columns or "*" in columns: 577 return 578 579 table_id = id(table) 580 columns_to_exclude = except_columns.get(table_id) or set() 581 582 if pivot: 583 if pivot_output_columns and pivot_exclude_columns: 584 pivot_columns = [c for c in columns if c not in pivot_exclude_columns] 585 pivot_columns.extend(pivot_output_columns) 586 else: 587 pivot_columns = pivot.alias_column_names 588 589 if pivot_columns: 590 new_selections.extend( 591 alias(exp.column(name, table=pivot.alias), name, copy=False) 592 for name in pivot_columns 593 if name not in columns_to_exclude 594 ) 595 continue 596 597 for name in columns: 598 if name in columns_to_exclude or name in coalesced_columns: 599 continue 600 if name in using_column_tables and table in using_column_tables[name]: 601 coalesced_columns.add(name) 602 tables = using_column_tables[name] 603 coalesce_args = [exp.column(name, table=table) for table in tables] 604 605 new_selections.append( 606 alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False) 607 ) 608 else: 609 alias_ = replace_columns.get(table_id, {}).get(name, name) 610 column = exp.column(name, table=table) 611 new_selections.append( 612 alias(column, alias_, copy=False) if alias_ != name else column 613 ) 614 615 # Ensures we don't overwrite the initial selections with an empty list 616 if new_selections and 
isinstance(scope.expression, exp.Select): 617 scope.expression.set("expressions", new_selections) 618 619 620def _add_except_columns( 621 expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] 622) -> None: 623 except_ = expression.args.get("except") 624 625 if not except_: 626 return 627 628 columns = {e.name for e in except_} 629 630 for table in tables: 631 except_columns[id(table)] = columns 632 633 634def _add_replace_columns( 635 expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]] 636) -> None: 637 replace = expression.args.get("replace") 638 639 if not replace: 640 return 641 642 columns = {e.this.name: e.alias for e in replace} 643 644 for table in tables: 645 replace_columns[id(table)] = columns 646 647 648def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 649 """Ensure all output columns are aliased""" 650 if isinstance(scope_or_expression, exp.Expression): 651 scope = build_scope(scope_or_expression) 652 if not isinstance(scope, Scope): 653 return 654 else: 655 scope = scope_or_expression 656 657 new_selections = [] 658 for i, (selection, aliased_column) in enumerate( 659 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 660 ): 661 if selection is None: 662 break 663 664 if isinstance(selection, exp.Subquery): 665 if not selection.output_name: 666 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 667 elif not isinstance(selection, exp.Alias) and not selection.is_star: 668 selection = alias( 669 selection, 670 alias=selection.output_name or f"_col_{i}", 671 copy=False, 672 ) 673 if aliased_column: 674 selection.set("alias", exp.to_identifier(aliased_column)) 675 676 new_selections.append(selection) 677 678 if isinstance(scope.expression, exp.Select): 679 scope.expression.set("expressions", new_selections) 680 681 682def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 683 """Makes sure all 
identifiers that need to be quoted are quoted.""" 684 return expression.transform( 685 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 686 ) # type: ignore 687 688 689def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 690 """ 691 Pushes down the CTE alias columns into the projection, 692 693 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 694 695 Example: 696 >>> import sqlglot 697 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 698 >>> pushdown_cte_alias_columns(expression).sql() 699 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 700 701 Args: 702 expression: Expression to pushdown. 703 704 Returns: 705 The expression with the CTE aliases pushed down into the projection. 706 """ 707 for cte in expression.find_all(exp.CTE): 708 if cte.alias_column_names: 709 new_expressions = [] 710 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 711 if isinstance(projection, exp.Alias): 712 projection.set("alias", _alias) 713 else: 714 projection = alias(projection, alias=_alias) 715 new_expressions.append(projection) 716 cte.this.set("expressions", new_expressions) 717 718 return expression 719 720 721class Resolver: 722 """ 723 Helper for resolving columns. 724 725 This is a class so we can lazily load some things and easily share them across functions. 
726 """ 727 728 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 729 self.scope = scope 730 self.schema = schema 731 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 732 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 733 self._all_columns: t.Optional[t.Set[str]] = None 734 self._infer_schema = infer_schema 735 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {} 736 737 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 738 """ 739 Get the table for a column name. 740 741 Args: 742 column_name: The column name to find the table for. 743 Returns: 744 The table name if it can be found/inferred. 745 """ 746 if self._unambiguous_columns is None: 747 self._unambiguous_columns = self._get_unambiguous_columns( 748 self._get_all_source_columns() 749 ) 750 751 table_name = self._unambiguous_columns.get(column_name) 752 753 if not table_name and self._infer_schema: 754 sources_without_schema = tuple( 755 source 756 for source, columns in self._get_all_source_columns().items() 757 if not columns or "*" in columns 758 ) 759 if len(sources_without_schema) == 1: 760 table_name = sources_without_schema[0] 761 762 if table_name not in self.scope.selected_sources: 763 return exp.to_identifier(table_name) 764 765 node, _ = self.scope.selected_sources.get(table_name) 766 767 if isinstance(node, exp.Query): 768 while node and node.alias != table_name: 769 node = node.parent 770 771 node_alias = node.args.get("alias") 772 if node_alias: 773 return exp.to_identifier(node_alias.this) 774 775 return exp.to_identifier(table_name) 776 777 @property 778 def all_columns(self) -> t.Set[str]: 779 """All available columns of all sources in this scope""" 780 if self._all_columns is None: 781 self._all_columns = { 782 column for columns in self._get_all_source_columns().values() for column in columns 783 } 784 return self._all_columns 785 786 def get_source_columns(self, name: str, 
only_visible: bool = False) -> t.Sequence[str]: 787 """Resolve the source columns for a given source `name`.""" 788 cache_key = (name, only_visible) 789 if cache_key not in self._get_source_columns_cache: 790 if name not in self.scope.sources: 791 raise OptimizeError(f"Unknown table: {name}") 792 793 source = self.scope.sources[name] 794 795 if isinstance(source, exp.Table): 796 columns = self.schema.column_names(source, only_visible) 797 elif isinstance(source, Scope) and isinstance( 798 source.expression, (exp.Values, exp.Unnest) 799 ): 800 columns = source.expression.named_selects 801 802 # in bigquery, unnest structs are automatically scoped as tables, so you can 803 # directly select a struct field in a query. 804 # this handles the case where the unnest is statically defined. 805 if self.schema.dialect == "bigquery": 806 if source.expression.is_type(exp.DataType.Type.STRUCT): 807 for k in source.expression.type.expressions: # type: ignore 808 columns.append(k.name) 809 else: 810 columns = source.expression.named_selects 811 812 node, _ = self.scope.selected_sources.get(name) or (None, None) 813 if isinstance(node, Scope): 814 column_aliases = node.expression.alias_column_names 815 elif isinstance(node, exp.Expression): 816 column_aliases = node.alias_column_names 817 else: 818 column_aliases = [] 819 820 if column_aliases: 821 # If the source's columns are aliased, their aliases shadow the corresponding column names. 822 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 
823 columns = [ 824 alias or name 825 for (name, alias) in itertools.zip_longest(columns, column_aliases) 826 ] 827 828 self._get_source_columns_cache[cache_key] = columns 829 830 return self._get_source_columns_cache[cache_key] 831 832 def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: 833 if self._source_columns is None: 834 self._source_columns = { 835 source_name: self.get_source_columns(source_name) 836 for source_name, source in itertools.chain( 837 self.scope.selected_sources.items(), self.scope.lateral_sources.items() 838 ) 839 } 840 return self._source_columns 841 842 def _get_unambiguous_columns( 843 self, source_columns: t.Dict[str, t.Sequence[str]] 844 ) -> t.Mapping[str, str]: 845 """ 846 Find all the unambiguous columns in sources. 847 848 Args: 849 source_columns: Mapping of names to source columns. 850 851 Returns: 852 Mapping of column name to source name. 853 """ 854 if not source_columns: 855 return {} 856 857 source_columns_pairs = list(source_columns.items()) 858 859 first_table, first_columns = source_columns_pairs[0] 860 861 if len(source_columns_pairs) == 1: 862 # Performance optimization - avoid copying first_columns if there is only one table. 863 return SingleValuedMapping(first_columns, first_table) 864 865 unambiguous_columns = {col: first_table for col in first_columns} 866 all_columns = set(unambiguous_columns) 867 868 for table, columns in source_columns_pairs[1:]: 869 unique = set(columns) 870 ambiguous = all_columns.intersection(unique) 871 all_columns.update(columns) 872 873 for column in ambiguous: 874 unambiguous_columns.pop(column, None) 875 for column in unique.difference(ambiguous): 876 unambiguous_columns[column] = table 877 878 return unambiguous_columns
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
20def qualify_columns( 21 expression: exp.Expression, 22 schema: t.Dict | Schema, 23 expand_alias_refs: bool = True, 24 expand_stars: bool = True, 25 infer_schema: t.Optional[bool] = None, 26) -> exp.Expression: 27 """ 28 Rewrite sqlglot AST to have fully qualified columns. 29 30 Example: 31 >>> import sqlglot 32 >>> schema = {"tbl": {"col": "INT"}} 33 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 34 >>> qualify_columns(expression, schema).sql() 35 'SELECT tbl.col AS col FROM tbl' 36 37 Args: 38 expression: Expression to qualify. 39 schema: Database schema. 40 expand_alias_refs: Whether to expand references to aliases. 41 expand_stars: Whether to expand star queries. This is a necessary step 42 for most of the optimizer's rules to work; do not set to False unless you 43 know what you're doing! 44 infer_schema: Whether to infer the schema if missing. 45 46 Returns: 47 The qualified expression. 48 49 Notes: 50 - Currently only handles a single PIVOT or UNPIVOT operator 51 """ 52 schema = ensure_schema(schema) 53 annotator = TypeAnnotator(schema) 54 infer_schema = schema.empty if infer_schema is None else infer_schema 55 dialect = Dialect.get_or_raise(schema.dialect) 56 pseudocolumns = dialect.PSEUDOCOLUMNS 57 58 for scope in traverse_scope(expression): 59 resolver = Resolver(scope, schema, infer_schema=infer_schema) 60 _pop_table_column_aliases(scope.ctes) 61 _pop_table_column_aliases(scope.derived_tables) 62 using_column_tables = _expand_using(scope, resolver) 63 64 if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs: 65 _expand_alias_refs( 66 scope, 67 resolver, 68 expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY, 69 ) 70 71 _convert_columns_to_dots(scope, resolver) 72 _qualify_columns(scope, resolver) 73 74 if not schema.empty and expand_alias_refs: 75 _expand_alias_refs(scope, resolver) 76 77 if not isinstance(scope.expression, exp.UDTF): 78 if expand_stars: 79 _expand_stars( 80 scope, 81 resolver, 82 
using_column_tables, 83 pseudocolumns, 84 annotator, 85 ) 86 qualify_outputs(scope) 87 88 _expand_group_by(scope, dialect) 89 _expand_order_by(scope, resolver) 90 91 if dialect == "bigquery": 92 annotator.annotate_scope(scope) 93 94 return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether to expand references to aliases.
- expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
- infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
Notes:
- Currently only handles a single PIVOT or UNPIVOT operator
def
validate_qualify_columns(expression: ~E) -> ~E:
97def validate_qualify_columns(expression: E) -> E: 98 """Raise an `OptimizeError` if any columns aren't qualified""" 99 all_unqualified_columns = [] 100 for scope in traverse_scope(expression): 101 if isinstance(scope.expression, exp.Select): 102 unqualified_columns = scope.unqualified_columns 103 104 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 105 column = scope.external_columns[0] 106 for_table = f" for table: '{column.table}'" if column.table else "" 107 raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") 108 109 if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: 110 # New columns produced by the UNPIVOT can't be qualified, but there may be columns 111 # under the UNPIVOT's IN clause that can and should be qualified. We recompute 112 # this list here to ensure those in the former category will be excluded. 113 unpivot_columns = set(_unpivot_columns(scope.pivots[0])) 114 unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] 115 116 all_unqualified_columns.extend(unqualified_columns) 117 118 if all_unqualified_columns: 119 raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") 120 121 return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
649def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: 650 """Ensure all output columns are aliased""" 651 if isinstance(scope_or_expression, exp.Expression): 652 scope = build_scope(scope_or_expression) 653 if not isinstance(scope, Scope): 654 return 655 else: 656 scope = scope_or_expression 657 658 new_selections = [] 659 for i, (selection, aliased_column) in enumerate( 660 itertools.zip_longest(scope.expression.selects, scope.outer_columns) 661 ): 662 if selection is None: 663 break 664 665 if isinstance(selection, exp.Subquery): 666 if not selection.output_name: 667 selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 668 elif not isinstance(selection, exp.Alias) and not selection.is_star: 669 selection = alias( 670 selection, 671 alias=selection.output_name or f"_col_{i}", 672 copy=False, 673 ) 674 if aliased_column: 675 selection.set("alias", exp.to_identifier(aliased_column)) 676 677 new_selections.append(selection) 678 679 if isinstance(scope.expression, exp.Select): 680 scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
683def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 684 """Makes sure all identifiers that need to be quoted are quoted.""" 685 return expression.transform( 686 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 687 ) # type: ignore
Makes sure all identifiers that need to be quoted are quoted.
def
pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
690def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: 691 """ 692 Pushes down the CTE alias columns into the projection, 693 694 This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. 695 696 Example: 697 >>> import sqlglot 698 >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") 699 >>> pushdown_cte_alias_columns(expression).sql() 700 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' 701 702 Args: 703 expression: Expression to pushdown. 704 705 Returns: 706 The expression with the CTE aliases pushed down into the projection. 707 """ 708 for cte in expression.find_all(exp.CTE): 709 if cte.alias_column_names: 710 new_expressions = [] 711 for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): 712 if isinstance(projection, exp.Alias): 713 projection.set("alias", _alias) 714 else: 715 projection = alias(projection, alias=_alias) 716 new_expressions.append(projection) 717 cte.this.set("expressions", new_expressions) 718 719 return expression
Pushes down the CTE alias columns into the projection.
This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
- expression: Expression to pushdown.
Returns:
The expression with the CTE aliases pushed down into the projection.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches; filled on first access by the methods below
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema
        # Memoizes get_source_columns, keyed by (source name, only_visible)
        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If the column wasn't resolved but exactly one source has no (or a
            # wildcard) column list, assume the column belongs to that source.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Query):
            # Walk up to the node that actually carries this alias
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """Resolve the source columns for a given source `name`."""
        cache_key = (name, only_visible)
        if cache_key not in self._get_source_columns_cache:
            if name not in self.scope.sources:
                raise OptimizeError(f"Unknown table: {name}")

            source = self.scope.sources[name]

            if isinstance(source, exp.Table):
                columns = self.schema.column_names(source, only_visible)
            elif isinstance(source, Scope) and isinstance(
                source.expression, (exp.Values, exp.Unnest)
            ):
                columns = source.expression.named_selects

                # in bigquery, unnest structs are automatically scoped as tables, so you can
                # directly select a struct field in a query.
                # this handles the case where the unnest is statically defined.
                if self.schema.dialect == "bigquery":
                    if source.expression.is_type(exp.DataType.Type.STRUCT):
                        for k in source.expression.type.expressions:  # type: ignore
                            columns.append(k.name)
            else:
                columns = source.expression.named_selects

            node, _ = self.scope.selected_sources.get(name) or (None, None)
            if isinstance(node, Scope):
                column_aliases = node.expression.alias_column_names
            elif isinstance(node, exp.Expression):
                column_aliases = node.alias_column_names
            else:
                column_aliases = []

            if column_aliases:
                # If the source's columns are aliased, their aliases shadow the corresponding column names.
                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
                columns = [
                    alias or name
                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
                ]

            self._get_source_columns_cache[cache_key] = columns

        return self._get_source_columns_cache[cache_key]

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        # Builds (once) the mapping of every selected/lateral source to its columns
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in more than one source stops being unambiguous
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
729 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 730 self.scope = scope 731 self.schema = schema 732 self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None 733 self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None 734 self._all_columns: t.Optional[t.Set[str]] = None 735 self._infer_schema = infer_schema 736 self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
738 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 739 """ 740 Get the table for a column name. 741 742 Args: 743 column_name: The column name to find the table for. 744 Returns: 745 The table name if it can be found/inferred. 746 """ 747 if self._unambiguous_columns is None: 748 self._unambiguous_columns = self._get_unambiguous_columns( 749 self._get_all_source_columns() 750 ) 751 752 table_name = self._unambiguous_columns.get(column_name) 753 754 if not table_name and self._infer_schema: 755 sources_without_schema = tuple( 756 source 757 for source, columns in self._get_all_source_columns().items() 758 if not columns or "*" in columns 759 ) 760 if len(sources_without_schema) == 1: 761 table_name = sources_without_schema[0] 762 763 if table_name not in self.scope.selected_sources: 764 return exp.to_identifier(table_name) 765 766 node, _ = self.scope.selected_sources.get(table_name) 767 768 if isinstance(node, exp.Query): 769 while node and node.alias != table_name: 770 node = node.parent 771 772 node_alias = node.args.get("alias") 773 if node_alias: 774 return exp.to_identifier(node_alias.this) 775 776 return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
all_columns: Set[str]
778 @property 779 def all_columns(self) -> t.Set[str]: 780 """All available columns of all sources in this scope""" 781 if self._all_columns is None: 782 self._all_columns = { 783 column for columns in self._get_all_source_columns().values() for column in columns 784 } 785 return self._all_columns
All available columns of all sources in this scope
def
get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
787 def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: 788 """Resolve the source columns for a given source `name`.""" 789 cache_key = (name, only_visible) 790 if cache_key not in self._get_source_columns_cache: 791 if name not in self.scope.sources: 792 raise OptimizeError(f"Unknown table: {name}") 793 794 source = self.scope.sources[name] 795 796 if isinstance(source, exp.Table): 797 columns = self.schema.column_names(source, only_visible) 798 elif isinstance(source, Scope) and isinstance( 799 source.expression, (exp.Values, exp.Unnest) 800 ): 801 columns = source.expression.named_selects 802 803 # in bigquery, unnest structs are automatically scoped as tables, so you can 804 # directly select a struct field in a query. 805 # this handles the case where the unnest is statically defined. 806 if self.schema.dialect == "bigquery": 807 if source.expression.is_type(exp.DataType.Type.STRUCT): 808 for k in source.expression.type.expressions: # type: ignore 809 columns.append(k.name) 810 else: 811 columns = source.expression.named_selects 812 813 node, _ = self.scope.selected_sources.get(name) or (None, None) 814 if isinstance(node, Scope): 815 column_aliases = node.expression.alias_column_names 816 elif isinstance(node, exp.Expression): 817 column_aliases = node.alias_column_names 818 else: 819 column_aliases = [] 820 821 if column_aliases: 822 # If the source's columns are aliased, their aliases shadow the corresponding column names. 823 # This can be expensive if there are lots of columns, so only do this if column_aliases exist. 824 columns = [ 825 alias or name 826 for (name, alias) in itertools.zip_longest(columns, column_aliases) 827 ] 828 829 self._get_source_columns_cache[cache_key] = columns 830 831 return self._get_source_columns_cache[cache_key]
Resolve the source columns for a given source `name`.