sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.annotate_types import TypeAnnotator
 11from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot._typing import E
 17
 18
 19def qualify_columns(
 20    expression: exp.Expression,
 21    schema: t.Dict | Schema,
 22    expand_alias_refs: bool = True,
 23    expand_stars: bool = True,
 24    infer_schema: t.Optional[bool] = None,
 25) -> exp.Expression:
 26    """
 27    Rewrite sqlglot AST to have fully qualified columns.
 28
 29    Example:
 30        >>> import sqlglot
 31        >>> schema = {"tbl": {"col": "INT"}}
 32        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 33        >>> qualify_columns(expression, schema).sql()
 34        'SELECT tbl.col AS col FROM tbl'
 35
 36    Args:
 37        expression: Expression to qualify.
 38        schema: Database schema.
 39        expand_alias_refs: Whether to expand references to aliases.
 40        expand_stars: Whether to expand star queries. This is a necessary step
 41            for most of the optimizer's rules to work; do not set to False unless you
 42            know what you're doing!
 43        infer_schema: Whether to infer the schema if missing.
 44
 45    Returns:
 46        The qualified expression.
 47
 48    Notes:
 49        - Currently only handles a single PIVOT or UNPIVOT operator
 50    """
 51    schema = ensure_schema(schema)
 52    annotator = TypeAnnotator(schema)
 53    infer_schema = schema.empty if infer_schema is None else infer_schema
 54    dialect = Dialect.get_or_raise(schema.dialect)
 55    pseudocolumns = dialect.PSEUDOCOLUMNS
 56
 57    for scope in traverse_scope(expression):
 58        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 59        _pop_table_column_aliases(scope.ctes)
 60        _pop_table_column_aliases(scope.derived_tables)
 61        using_column_tables = _expand_using(scope, resolver)
 62
 63        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
 64            _expand_alias_refs(
 65                scope,
 66                resolver,
 67                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
 68            )
 69
 70        _convert_columns_to_dots(scope, resolver)
 71        _qualify_columns(scope, resolver)
 72
 73        if not schema.empty and expand_alias_refs:
 74            _expand_alias_refs(scope, resolver)
 75
 76        if not isinstance(scope.expression, exp.UDTF):
 77            if expand_stars:
 78                _expand_stars(
 79                    scope,
 80                    resolver,
 81                    using_column_tables,
 82                    pseudocolumns,
 83                    annotator,
 84                )
 85            qualify_outputs(scope)
 86
 87        _expand_group_by(scope, dialect)
 88        _expand_order_by(scope, resolver)
 89
 90        if dialect == "bigquery":
 91            annotator.annotate_scope(scope)
 92
 93    return expression
 94
 95
 96def validate_qualify_columns(expression: E) -> E:
 97    """Raise an `OptimizeError` if any columns aren't qualified"""
 98    all_unqualified_columns = []
 99    for scope in traverse_scope(expression):
100        if isinstance(scope.expression, exp.Select):
101            unqualified_columns = scope.unqualified_columns
102
103            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
104                column = scope.external_columns[0]
105                for_table = f" for table: '{column.table}'" if column.table else ""
106                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
107
108            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
109                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
110                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
111                # this list here to ensure those in the former category will be excluded.
112                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
113                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
114
115            all_unqualified_columns.extend(unqualified_columns)
116
117    if all_unqualified_columns:
118        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
119
120    return expression
121
122
123def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
124    name_column = []
125    field = unpivot.args.get("field")
126    if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
127        name_column.append(field.this)
128
129    value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
130    return itertools.chain(name_column, value_columns)
131
132
133def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
134    """
135    Remove table column aliases.
136
137    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
138    """
139    for derived_table in derived_tables:
140        if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive:
141            continue
142        table_alias = derived_table.args.get("alias")
143        if table_alias:
144            table_alias.args.pop("columns", None)
145
146
147def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
148    joins = list(scope.find_all(exp.Join))
149    names = {join.alias_or_name for join in joins}
150    ordered = [key for key in scope.selected_sources if key not in names]
151
152    # Mapping of automatically joined column names to an ordered set of source names (dict).
153    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
154
155    for i, join in enumerate(joins):
156        using = join.args.get("using")
157
158        if not using:
159            continue
160
161        join_table = join.alias_or_name
162
163        columns = {}
164
165        for source_name in scope.selected_sources:
166            if source_name in ordered:
167                for column_name in resolver.get_source_columns(source_name):
168                    if column_name not in columns:
169                        columns[column_name] = source_name
170
171        source_table = ordered[-1]
172        ordered.append(join_table)
173        join_columns = resolver.get_source_columns(join_table)
174        conditions = []
175        using_identifier_count = len(using)
176
177        for identifier in using:
178            identifier = identifier.name
179            table = columns.get(identifier)
180
181            if not table or identifier not in join_columns:
182                if (columns and "*" not in columns) and join_columns:
183                    raise OptimizeError(f"Cannot automatically join: {identifier}")
184
185            table = table or source_table
186
187            if i == 0 or using_identifier_count == 1:
188                lhs: exp.Expression = exp.column(identifier, table=table)
189            else:
190                coalesce_columns = [
191                    exp.column(identifier, table=t)
192                    for t in ordered[:-1]
193                    if identifier in resolver.get_source_columns(t)
194                ]
195                if len(coalesce_columns) > 1:
196                    lhs = exp.func("coalesce", *coalesce_columns)
197                else:
198                    lhs = exp.column(identifier, table=table)
199
200            conditions.append(lhs.eq(exp.column(identifier, table=join_table)))
201
202            # Set all values in the dict to None, because we only care about the key ordering
203            tables = column_tables.setdefault(identifier, {})
204            if table not in tables:
205                tables[table] = None
206            if join_table not in tables:
207                tables[join_table] = None
208
209        join.args.pop("using")
210        join.set("on", exp.and_(*conditions, copy=False))
211
212    if column_tables:
213        for column in scope.columns:
214            if not column.table and column.name in column_tables:
215                tables = column_tables[column.name]
216                coalesce_args = [exp.column(column.name, table=table) for table in tables]
217                replacement = exp.func("coalesce", *coalesce_args)
218
219                # Ensure selects keep their output name
220                if isinstance(column.parent, exp.Select):
221                    replacement = alias(replacement, alias=column.name, copy=False)
222
223                scope.replace(column, replacement)
224
225    return column_tables
226
227
228def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
229    expression = scope.expression
230
231    if not isinstance(expression, exp.Select):
232        return
233
234    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
235
236    def replace_columns(
237        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
238    ) -> None:
239        if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
240            return
241
242        for column in walk_in_scope(node, prune=lambda node: node.is_star):
243            if not isinstance(column, exp.Column):
244                continue
245
246            table = resolver.get_table(column.name) if resolve_table and not column.table else None
247            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
248            double_agg = (
249                (
250                    alias_expr.find(exp.AggFunc)
251                    and (
252                        column.find_ancestor(exp.AggFunc)
253                        and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window)
254                    )
255                )
256                if alias_expr
257                else False
258            )
259
260            if table and (not alias_expr or double_agg):
261                column.set("table", table)
262            elif not column.table and alias_expr and not double_agg:
263                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
264                    if literal_index:
265                        column.replace(exp.Literal.number(i))
266                else:
267                    column = column.replace(exp.paren(alias_expr))
268                    simplified = simplify_parens(column)
269                    if simplified is not column:
270                        column.replace(simplified)
271
272    for i, projection in enumerate(scope.expression.selects):
273        replace_columns(projection)
274
275        if isinstance(projection, exp.Alias):
276            alias_to_expression[projection.alias] = (projection.this, i + 1)
277
278    replace_columns(expression.args.get("where"))
279    replace_columns(expression.args.get("group"), literal_index=True)
280    replace_columns(expression.args.get("having"), resolve_table=True)
281    replace_columns(expression.args.get("qualify"), resolve_table=True)
282
283    scope.clear_cache()
284
285
286def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
287    expression = scope.expression
288    group = expression.args.get("group")
289    if not group:
290        return
291
292    group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
293    expression.set("group", group)
294
295
296def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
297    order = scope.expression.args.get("order")
298    if not order:
299        return
300
301    ordereds = order.expressions
302    for ordered, new_expression in zip(
303        ordereds,
304        _expand_positional_references(
305            scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
306        ),
307    ):
308        for agg in ordered.find_all(exp.AggFunc):
309            for col in agg.find_all(exp.Column):
310                if not col.table:
311                    col.set("table", resolver.get_table(col.name))
312
313        ordered.set("this", new_expression)
314
315    if scope.expression.args.get("group"):
316        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
317
318        for ordered in ordereds:
319            ordered = ordered.this
320
321            ordered.replace(
322                exp.to_identifier(_select_by_pos(scope, ordered).alias)
323                if ordered.is_int
324                else selects.get(ordered, ordered)
325            )
326
327
328def _expand_positional_references(
329    scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
330) -> t.List[exp.Expression]:
331    new_nodes: t.List[exp.Expression] = []
332    ambiguous_projections = None
333
334    for node in expressions:
335        if node.is_int:
336            select = _select_by_pos(scope, t.cast(exp.Literal, node))
337
338            if alias:
339                new_nodes.append(exp.column(select.args["alias"].copy()))
340            else:
341                select = select.this
342
343                if dialect == "bigquery":
344                    if ambiguous_projections is None:
345                        # When a projection name is also a source name and it is referenced in the
346                        # GROUP BY clause, BQ can't understand what the identifier corresponds to
347                        ambiguous_projections = {
348                            s.alias_or_name
349                            for s in scope.expression.selects
350                            if s.alias_or_name in scope.selected_sources
351                        }
352
353                    ambiguous = any(
354                        column.parts[0].name in ambiguous_projections
355                        for column in select.find_all(exp.Column)
356                    )
357                else:
358                    ambiguous = False
359
360                if (
361                    isinstance(select, exp.CONSTANTS)
362                    or select.find(exp.Explode, exp.Unnest)
363                    or ambiguous
364                ):
365                    new_nodes.append(node)
366                else:
367                    new_nodes.append(select.copy())
368        else:
369            new_nodes.append(node)
370
371    return new_nodes
372
373
374def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
375    try:
376        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
377    except IndexError:
378        raise OptimizeError(f"Unknown output column: {node.name}")
379
380
381def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
382    """
383    Converts `Column` instances that represent struct field lookup into chained `Dots`.
384
385    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
386    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
387    """
388    converted = False
389    for column in itertools.chain(scope.columns, scope.stars):
390        if isinstance(column, exp.Dot):
391            continue
392
393        column_table: t.Optional[str | exp.Identifier] = column.table
394        if (
395            column_table
396            and column_table not in scope.sources
397            and (
398                not scope.parent
399                or column_table not in scope.parent.sources
400                or not scope.is_correlated_subquery
401            )
402        ):
403            root, *parts = column.parts
404
405            if root.name in scope.sources:
406                # The struct is already qualified, but we still need to change the AST
407                column_table = root
408                root, *parts = parts
409            else:
410                column_table = resolver.get_table(root.name)
411
412            if column_table:
413                converted = True
414                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
415
416    if converted:
417        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
418        # a `for column in scope.columns` iteration, even though they shouldn't be
419        scope.clear_cache()
420
421
422def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
423    """Disambiguate columns, ensuring each column specifies a source"""
424    for column in scope.columns:
425        column_table = column.table
426        column_name = column.name
427
428        if column_table and column_table in scope.sources:
429            source_columns = resolver.get_source_columns(column_table)
430            if source_columns and column_name not in source_columns and "*" not in source_columns:
431                raise OptimizeError(f"Unknown column: {column_name}")
432
433        if not column_table:
434            if scope.pivots and not column.find_ancestor(exp.Pivot):
435                # If the column is under the Pivot expression, we need to qualify it
436                # using the name of the pivoted source instead of the pivot's alias
437                column.set("table", exp.to_identifier(scope.pivots[0].alias))
438                continue
439
440            # column_table can be an empty string because bigquery unnest has no table alias
441            column_table = resolver.get_table(column_name)
442            if column_table:
443                column.set("table", column_table)
444
445    for pivot in scope.pivots:
446        for column in pivot.find_all(exp.Column):
447            if not column.table and column.name in resolver.all_columns:
448                column_table = resolver.get_table(column.name)
449                if column_table:
450                    column.set("table", column_table)
451
452
453def _expand_struct_stars(
454    expression: exp.Dot,
455) -> t.List[exp.Alias]:
456    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
457
458    dot_column = t.cast(exp.Column, expression.find(exp.Column))
459    if not dot_column.is_type(exp.DataType.Type.STRUCT):
460        return []
461
462    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
463    dot_column = dot_column.copy()
464    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
465
466    # First part is the table name and last part is the star so they can be dropped
467    dot_parts = expression.parts[1:-1]
468
469    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
470    for part in dot_parts[1:]:
471        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
472            # Unable to expand star unless all fields are named
473            if not isinstance(field.this, exp.Identifier):
474                return []
475
476            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
477                starting_struct = field
478                break
479        else:
480            # There is no matching field in the struct
481            return []
482
483    taken_names = set()
484    new_selections = []
485
486    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
487        name = field.name
488
489        # Ambiguous or anonymous fields can't be expanded
490        if name in taken_names or not isinstance(field.this, exp.Identifier):
491            return []
492
493        taken_names.add(name)
494
495        this = field.this.copy()
496        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
497        new_column = exp.column(
498            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
499        )
500        new_selections.append(alias(new_column, this, copy=False))
501
502    return new_selections
503
504
505def _expand_stars(
506    scope: Scope,
507    resolver: Resolver,
508    using_column_tables: t.Dict[str, t.Any],
509    pseudocolumns: t.Set[str],
510    annotator: TypeAnnotator,
511) -> None:
512    """Expand stars to lists of column selections"""
513
514    new_selections = []
515    except_columns: t.Dict[int, t.Set[str]] = {}
516    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
517    coalesced_columns = set()
518    dialect = resolver.schema.dialect
519
520    pivot_output_columns = None
521    pivot_exclude_columns = None
522
523    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
524    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
525        if pivot.unpivot:
526            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
527
528            field = pivot.args.get("field")
529            if isinstance(field, exp.In):
530                pivot_exclude_columns = {
531                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
532                }
533        else:
534            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
535
536            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
537            if not pivot_output_columns:
538                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
539
540    is_bigquery = dialect == "bigquery"
541    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
542        # Found struct expansion, annotate scope ahead of time
543        annotator.annotate_scope(scope)
544
545    for expression in scope.expression.selects:
546        tables = []
547        if isinstance(expression, exp.Star):
548            tables.extend(scope.selected_sources)
549            _add_except_columns(expression, tables, except_columns)
550            _add_replace_columns(expression, tables, replace_columns)
551        elif expression.is_star:
552            if not isinstance(expression, exp.Dot):
553                tables.append(expression.table)
554                _add_except_columns(expression.this, tables, except_columns)
555                _add_replace_columns(expression.this, tables, replace_columns)
556            elif is_bigquery:
557                struct_fields = _expand_struct_stars(expression)
558                if struct_fields:
559                    new_selections.extend(struct_fields)
560                    continue
561
562        if not tables:
563            new_selections.append(expression)
564            continue
565
566        for table in tables:
567            if table not in scope.sources:
568                raise OptimizeError(f"Unknown table: {table}")
569
570            columns = resolver.get_source_columns(table, only_visible=True)
571            columns = columns or scope.outer_columns
572
573            if pseudocolumns:
574                columns = [name for name in columns if name.upper() not in pseudocolumns]
575
576            if not columns or "*" in columns:
577                return
578
579            table_id = id(table)
580            columns_to_exclude = except_columns.get(table_id) or set()
581
582            if pivot:
583                if pivot_output_columns and pivot_exclude_columns:
584                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
585                    pivot_columns.extend(pivot_output_columns)
586                else:
587                    pivot_columns = pivot.alias_column_names
588
589                if pivot_columns:
590                    new_selections.extend(
591                        alias(exp.column(name, table=pivot.alias), name, copy=False)
592                        for name in pivot_columns
593                        if name not in columns_to_exclude
594                    )
595                    continue
596
597            for name in columns:
598                if name in columns_to_exclude or name in coalesced_columns:
599                    continue
600                if name in using_column_tables and table in using_column_tables[name]:
601                    coalesced_columns.add(name)
602                    tables = using_column_tables[name]
603                    coalesce_args = [exp.column(name, table=table) for table in tables]
604
605                    new_selections.append(
606                        alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
607                    )
608                else:
609                    alias_ = replace_columns.get(table_id, {}).get(name, name)
610                    column = exp.column(name, table=table)
611                    new_selections.append(
612                        alias(column, alias_, copy=False) if alias_ != name else column
613                    )
614
615    # Ensures we don't overwrite the initial selections with an empty list
616    if new_selections and isinstance(scope.expression, exp.Select):
617        scope.expression.set("expressions", new_selections)
618
619
620def _add_except_columns(
621    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
622) -> None:
623    except_ = expression.args.get("except")
624
625    if not except_:
626        return
627
628    columns = {e.name for e in except_}
629
630    for table in tables:
631        except_columns[id(table)] = columns
632
633
634def _add_replace_columns(
635    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
636) -> None:
637    replace = expression.args.get("replace")
638
639    if not replace:
640        return
641
642    columns = {e.this.name: e.alias for e in replace}
643
644    for table in tables:
645        replace_columns[id(table)] = columns
646
647
648def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
649    """Ensure all output columns are aliased"""
650    if isinstance(scope_or_expression, exp.Expression):
651        scope = build_scope(scope_or_expression)
652        if not isinstance(scope, Scope):
653            return
654    else:
655        scope = scope_or_expression
656
657    new_selections = []
658    for i, (selection, aliased_column) in enumerate(
659        itertools.zip_longest(scope.expression.selects, scope.outer_columns)
660    ):
661        if selection is None:
662            break
663
664        if isinstance(selection, exp.Subquery):
665            if not selection.output_name:
666                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
667        elif not isinstance(selection, exp.Alias) and not selection.is_star:
668            selection = alias(
669                selection,
670                alias=selection.output_name or f"_col_{i}",
671                copy=False,
672            )
673        if aliased_column:
674            selection.set("alias", exp.to_identifier(aliased_column))
675
676        new_selections.append(selection)
677
678    if isinstance(scope.expression, exp.Select):
679        scope.expression.set("expressions", new_selections)
680
681
682def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
683    """Makes sure all identifiers that need to be quoted are quoted."""
684    return expression.transform(
685        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
686    )  # type: ignore
687
688
689def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
690    """
691    Pushes down the CTE alias columns into the projection,
692
693    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
694
695    Example:
696        >>> import sqlglot
697        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
698        >>> pushdown_cte_alias_columns(expression).sql()
699        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
700
701    Args:
702        expression: Expression to pushdown.
703
704    Returns:
705        The expression with the CTE aliases pushed down into the projection.
706    """
707    for cte in expression.find_all(exp.CTE):
708        if cte.alias_column_names:
709            new_expressions = []
710            for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
711                if isinstance(projection, exp.Alias):
712                    projection.set("alias", _alias)
713                else:
714                    projection = alias(projection, alias=_alias)
715                new_expressions.append(projection)
716            cte.this.set("expressions", new_expressions)
717
718    return expression
719
720
721class Resolver:
722    """
723    Helper for resolving columns.
724
725    This is a class so we can lazily load some things and easily share them across functions.
726    """
727
728    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
729        self.scope = scope
730        self.schema = schema
731        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
732        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
733        self._all_columns: t.Optional[t.Set[str]] = None
734        self._infer_schema = infer_schema
735        self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
736
737    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
738        """
739        Get the table for a column name.
740
741        Args:
742            column_name: The column name to find the table for.
743        Returns:
744            The table name if it can be found/inferred.
745        """
746        if self._unambiguous_columns is None:
747            self._unambiguous_columns = self._get_unambiguous_columns(
748                self._get_all_source_columns()
749            )
750
751        table_name = self._unambiguous_columns.get(column_name)
752
753        if not table_name and self._infer_schema:
754            sources_without_schema = tuple(
755                source
756                for source, columns in self._get_all_source_columns().items()
757                if not columns or "*" in columns
758            )
759            if len(sources_without_schema) == 1:
760                table_name = sources_without_schema[0]
761
762        if table_name not in self.scope.selected_sources:
763            return exp.to_identifier(table_name)
764
765        node, _ = self.scope.selected_sources.get(table_name)
766
767        if isinstance(node, exp.Query):
768            while node and node.alias != table_name:
769                node = node.parent
770
771        node_alias = node.args.get("alias")
772        if node_alias:
773            return exp.to_identifier(node_alias.this)
774
775        return exp.to_identifier(table_name)
776
777    @property
778    def all_columns(self) -> t.Set[str]:
779        """All available columns of all sources in this scope"""
780        if self._all_columns is None:
781            self._all_columns = {
782                column for columns in self._get_all_source_columns().values() for column in columns
783            }
784        return self._all_columns
785
786    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
787        """Resolve the source columns for a given source `name`."""
788        cache_key = (name, only_visible)
789        if cache_key not in self._get_source_columns_cache:
790            if name not in self.scope.sources:
791                raise OptimizeError(f"Unknown table: {name}")
792
793            source = self.scope.sources[name]
794
795            if isinstance(source, exp.Table):
796                columns = self.schema.column_names(source, only_visible)
797            elif isinstance(source, Scope) and isinstance(
798                source.expression, (exp.Values, exp.Unnest)
799            ):
800                columns = source.expression.named_selects
801
802                # in bigquery, unnest structs are automatically scoped as tables, so you can
803                # directly select a struct field in a query.
804                # this handles the case where the unnest is statically defined.
805                if self.schema.dialect == "bigquery":
806                    if source.expression.is_type(exp.DataType.Type.STRUCT):
807                        for k in source.expression.type.expressions:  # type: ignore
808                            columns.append(k.name)
809            else:
810                columns = source.expression.named_selects
811
812            node, _ = self.scope.selected_sources.get(name) or (None, None)
813            if isinstance(node, Scope):
814                column_aliases = node.expression.alias_column_names
815            elif isinstance(node, exp.Expression):
816                column_aliases = node.alias_column_names
817            else:
818                column_aliases = []
819
820            if column_aliases:
821                # If the source's columns are aliased, their aliases shadow the corresponding column names.
822                # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
823                columns = [
824                    alias or name
825                    for (name, alias) in itertools.zip_longest(columns, column_aliases)
826                ]
827
828            self._get_source_columns_cache[cache_key] = columns
829
830        return self._get_source_columns_cache[cache_key]
831
832    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
833        if self._source_columns is None:
834            self._source_columns = {
835                source_name: self.get_source_columns(source_name)
836                for source_name, source in itertools.chain(
837                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
838                )
839            }
840        return self._source_columns
841
842    def _get_unambiguous_columns(
843        self, source_columns: t.Dict[str, t.Sequence[str]]
844    ) -> t.Mapping[str, str]:
845        """
846        Find all the unambiguous columns in sources.
847
848        Args:
849            source_columns: Mapping of names to source columns.
850
851        Returns:
852            Mapping of column name to source name.
853        """
854        if not source_columns:
855            return {}
856
857        source_columns_pairs = list(source_columns.items())
858
859        first_table, first_columns = source_columns_pairs[0]
860
861        if len(source_columns_pairs) == 1:
862            # Performance optimization - avoid copying first_columns if there is only one table.
863            return SingleValuedMapping(first_columns, first_table)
864
865        unambiguous_columns = {col: first_table for col in first_columns}
866        all_columns = set(unambiguous_columns)
867
868        for table, columns in source_columns_pairs[1:]:
869            unique = set(columns)
870            ambiguous = all_columns.intersection(unique)
871            all_columns.update(columns)
872
873            for column in ambiguous:
874                unambiguous_columns.pop(column, None)
875            for column in unique.difference(ambiguous):
876                unambiguous_columns[column] = table
877
878        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
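
A slightly larger, hypothetical sketch of the same call, showing how a star is expanded across a join when a schema is supplied (table and column names are invented; the exact column order of the output may vary between sqlglot versions):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns

    # Two made-up tables that share a column name ("id").
    schema = {
        "orders": {"id": "INT", "user_id": "INT"},
        "users": {"id": "INT", "name": "TEXT"},
    }
    expression = sqlglot.parse_one(
        "SELECT * FROM orders JOIN users ON orders.user_id = users.id"
    )
    print(qualify_columns(expression, schema).sql())
    # Roughly: SELECT orders.id AS id, orders.user_id AS user_id,
    #          users.id AS id, users.name AS name
    #          FROM orders JOIN users ON orders.user_id = users.id
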
def validate_qualify_columns(expression: ~E) -> ~E:

Raise an OptimizeError if any columns aren't qualified
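
For illustration, a minimal sketch of how this pairs with `qualify_columns`: with no schema and two sources, a bare column cannot be attributed to either table, so validation fails (the exact error message may differ by version):

    import sqlglot
    from sqlglot.errors import OptimizeError
    from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

    expression = sqlglot.parse_one("SELECT b FROM x CROSS JOIN y")
    qualified = qualify_columns(expression, schema={})

    try:
        validate_qualify_columns(qualified)
    except OptimizeError as err:
        print("not fully qualified:", err)  # `b` is ambiguous between x and y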

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:

Ensure all output columns are aliased
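
A small sketch of the effect on unaliased projections (the table `x` is made up; the generated `_col_<i>` names follow the positional convention used in the source above):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_outputs

    expression = sqlglot.parse_one("SELECT x.a, 1 + 1 FROM x")
    qualify_outputs(expression)  # mutates the expression in place
    print(expression.sql())
    # Roughly: SELECT x.a AS a, 1 + 1 AS _col_1 FROM x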

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:

Makes sure all identifiers that need to be quoted are quoted.
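
For example (the dialect choice is arbitrary here; with the default identify=True, identifiers are quoted even when they would not strictly need it):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import quote_identifiers

    expression = sqlglot.parse_one("SELECT a FROM tbl")
    print(quote_identifiers(expression, dialect="duckdb").sql(dialect="duckdb"))
    # SELECT "a" FROM "tbl"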

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.
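
A minimal usage sketch (the schema here is hypothetical; `build_scope` and `ensure_schema` are the same helpers this module already imports):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import Resolver
    from sqlglot.optimizer.scope import build_scope
    from sqlglot.schema import ensure_schema

    schema = ensure_schema({"tbl": {"a": "INT", "b": "INT"}})
    scope = build_scope(sqlglot.parse_one("SELECT a FROM tbl"))
    resolver = Resolver(scope, schema)

    print(resolver.all_columns)                 # {'a', 'b'} (a set, order may vary)
    print(resolver.get_source_columns("tbl"))   # ['a', 'b']
    print(resolver.get_table("a").name)         # 'tbl', the only source with column `a`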

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.
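
A small sketch of the `infer_schema` fallback: with an empty schema and a single source, the column is attributed to that source even though its columns are unknown (names are made up for illustration):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import Resolver
    from sqlglot.optimizer.scope import build_scope
    from sqlglot.schema import ensure_schema

    scope = build_scope(sqlglot.parse_one("SELECT a FROM some_table"))
    resolver = Resolver(scope, ensure_schema({}), infer_schema=True)
    print(resolver.get_table("a"))  # some_table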

all_columns: Set[str]

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:

Resolve the source columns for a given source name.
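
A short sketch showing where the columns come from for a derived table: they are taken from the subquery's named projections rather than from the schema (names invented for illustration):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import Resolver
    from sqlglot.optimizer.scope import build_scope
    from sqlglot.schema import ensure_schema

    scope = build_scope(sqlglot.parse_one("SELECT * FROM (SELECT 1 AS a, 2 AS b) AS t"))
    resolver = Resolver(scope, ensure_schema({}))
    print(resolver.get_source_columns("t"))  # ['a', 'b']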