How to annotate PostgreSQL ASTs with location information

This blog post originally appeared on Propel's blog. You can also view it there.

If you're writing tools to inspect and manipulate PostgreSQL queries, then you might already be familiar with libpg_query. If not, allow me to introduce it: libpg_query is the actual parser PostgreSQL uses to turn SQL queries into abstract syntax trees (ASTs), packaged up as a reusable C library.

Compared to bespoke parsers, libpg_query offers near-perfect compatibility with PostgreSQL because it's the same code. That's awesome! However, if you compare the output of libpg_query to ANTLR- or Nearley-based parsers, you'll notice that, while the PostgreSQL AST includes start locations for every expression, it lacks the length or end locations that other parsers provide. This makes writing syntax or expression highlighters difficult, because you can't be sure where an expression ends and when to stop highlighting.

Thankfully, there is a technique using the PostgreSQL lexer that we can use to recover expressions' end locations, and that's what I'll show you in the rest of this post.

From C to TypeScript

First, rather than write this blog post in C, let's write it in TypeScript using Node.js. To do so, we need a library that loads libpg_query as a Node.js native addon. Up until recently, libpg-query would've been the clear choice for this; however, a fork with improved TypeScript definitions has emerged, @pg-nano/pg-parser, so let's use that:

cd $(mktemp -d)
npm init -y
npm install @pg-nano/pg-parser

Consider the following SELECT statement:

SELECT 1 + 2 AS three

We can parse it with parseQuerySync and print it to the console:

import { inspect } from 'node:util'

import { parseQuerySync } from '@pg-nano/pg-parser'

let input = 'SELECT 1 + 2 AS three'
let parseResult = parseQuerySync(input)
console.log(inspect(parseResult, false, null, true))

Doing so logs the following:

{
  version: 160001,
  stmts: [
    {
      stmt: {
        SelectStmt: {
          targetList: [
            {
              ResTarget: {
                name: 'three',
                val: {
                  A_Expr: {
                    kind: 'AEXPR_OP',
                    name: [ { String: { sval: '+' } } ],
                    lexpr: { A_Const: { ival: { ival: 1 }, location: 7 } },
                    rexpr: { A_Const: { ival: { ival: 2 }, location: 11 } },
                    location: 9
                  }
                },
                location: 7
              }
            }
          ],
          limitOption: 'LIMIT_OPTION_DEFAULT',
          op: 'SETOP_NONE'
        }
      }
    }
  ]
}

Amazing! Check out the location fields on the AST nodes. Each one is a single character offset into the original SQL text that tells us roughly where an expression occurred, but not where it starts or stops. (The A_Expr's location, 9, points at the + operator rather than at the start of 1 + 2.) These offsets include just enough information for PostgreSQL to associate error messages with expressions. For example, imagine we execute the following SQL statement, which incorrectly tries to sum a number and a string:

SELECT 1 + 'two' AS three;

PostgreSQL responds with the following error message, placing a caret symbol (^) at the start of the string literal "two", which is location 11 in the AST:

psql:commands.sql:1: ERROR:  invalid input syntax for type integer: "two"
LINE 1: SELECT 1 + 'two' AS three;
                   ^

So that's cool, but what if we wanted to highlight or underline individual expressions? How can we go from a single location field to a pair of start and stop locations?

Getting tokens from the PostgreSQL lexer

libpg_query exposes a function, pg_query_scan, which takes a SQL text as input and passes it to PostgreSQL's lexer for tokenization. Tokenization is the first step in parsing a PostgreSQL statement, where the input text is divided up and categorized into lexemes, or "tokens". Each token includes both its start and end location. So, by combining the tokens with the AST, we should be able to recover the start and end locations of expressions, too.

Since we're using TypeScript and Node.js, we need to expose pg_query_scan there, too. I've opened a pull request to do this, and I hope it will be merged. In the meantime, I've published my own release which we'll use for the remainder of the blog post:

npm install @markandrus/pg-parser

Now we can call scanSync to get our SELECT statement's tokens and print them to the console:

import { scanSync } from '@markandrus/pg-parser'

let tokens = scanSync(input)
console.log(inspect(tokens, false, null, true))

Doing so logs the following:

[
  { kind: 'SELECT', start: 0, end: 6, keyword: 'RESERVED_KEYWORD' },
  { kind: 'ICONST', start: 7, end: 8, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_43', start: 9, end: 10, keyword: 'NO_KEYWORD' },
  { kind: 'ICONST', start: 11, end: 12, keyword: 'NO_KEYWORD' },
  { kind: 'AS', start: 13, end: 15, keyword: 'RESERVED_KEYWORD' },
  { kind: 'IDENT', start: 16, end: 21, keyword: 'NO_KEYWORD' }
]

Notice how each token contains a start and end location, where end is exclusive, and how kind tells us something about the token. For example, ICONST tokens represent integer constants, IDENT tokens represent identifiers, and ASCII_43 is the single-character token + (ASCII code 43).
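
To make the start and end offsets concrete, here is a small sketch that slices each token's text back out of the input. It doesn't call the native addon: the Token interface and the scanned array are local stand-ins copied from the output above, and the names sql, scanned, and lexemes are only for illustration.

```typescript
// Local stand-in for the Token type; the fields mirror the output above.
interface Token {
  kind: string
  start: number
  end: number
  keyword: string
}

const sql = 'SELECT 1 + 2 AS three'

// Hard-coded copy of the scanSync output shown above.
const scanned: Token[] = [
  { kind: 'SELECT', start: 0, end: 6, keyword: 'RESERVED_KEYWORD' },
  { kind: 'ICONST', start: 7, end: 8, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_43', start: 9, end: 10, keyword: 'NO_KEYWORD' },
  { kind: 'ICONST', start: 11, end: 12, keyword: 'NO_KEYWORD' },
  { kind: 'AS', start: 13, end: 15, keyword: 'RESERVED_KEYWORD' },
  { kind: 'IDENT', start: 16, end: 21, keyword: 'NO_KEYWORD' }
]

// Because `end` is exclusive, slice(start, end) yields exactly the token's text.
const lexemes = scanned.map(token => sql.slice(token.start, token.end))
console.log(lexemes)
// → [ 'SELECT', '1', '+', '2', 'AS', 'three' ]
```

Whitespace never appears in a token, which is why the offsets 6, 8, 10, 12, and 15 don't start any token.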

Sketching a solution

Take a moment to review each token's start and end locations above, comparing them to our SELECT statement. If we were to underline each token, it would look like this:

SELECT 1 + 2 AS three
└────┘ ^ ^ ^ └┘ └───┘

Let's instead imagine what it would look like to underline each expression, starting with leaf expressions first:

SELECT 1 + 2 AS three
       ^

SELECT 1 + 2 AS three
           ^

SELECT 1 + 2 AS three
       └───┘

SELECT 1 + 2 AS three
       └────────────┘

SELECT 1 + 2 AS three
└───────────────────┘

Notice how

  • Every leaf expression maps to a token by way of its start location, and that token's start and end locations define the expression's start and end locations.
  • Every non-leaf expression's start and end locations can be defined by the start and end locations of its children's left- and right-most tokens, respectively.

From these two observations, we can define a base case and recursive case for a set of recursive functions that will visit the PostgreSQL AST and annotate each expression with its left- and right-most tokens. So let's write it!

Types and basic operations

Let's start by importing and defining some types and basic operations. To begin, let's define a Span type which holds a left- and right-most Token. We can say that a Span starts at its left Token's start position and ends at its right Token's end position:

import type { Token } from '@markandrus/pg-parser'

interface Span {
  left: Token
  right: Token
}

Given a Token, we can construct a trivial Span by setting left and right equal to each other:

function newSpan (token: Token): Span {
  return { left: token, right: token }
}

We can also merge Spans by taking the left- and right-most Tokens of the two Spans. For reasons that will become clear later, it's convenient to let the first argument of mergeSpans be optional:

function mergeSpans (s1: Span | undefined, s2: Span): Span {
  if (s1 == null) return s2

  return {
    left: takeLeftToken(s1.left, s2.left),
    right: takeRightToken(s1.right, s2.right)
  }
}

Taking the left- and right-most Token is straightforward:

function takeLeftToken (t1: Token, t2: Token): Token {
  return t1.start <= t2.start ? t1 : t2
}

function takeRightToken (t1: Token, t2: Token): Token {
  return t2.end >= t1.end ? t2 : t1
}

Next, let's define a generic expression type, Expr, and a type of State object that we will thread through our recursive functions. The State object includes a map from expression locations to Tokens and a map from expressions to Spans. As we visit the PostgreSQL AST, we'll look up tokens by expression location and update exprToSpan with our calculated Spans.

type Expr = Record<string, unknown>

interface State {
  locationToToken: Map<number, Token>
  exprToSpan: Map<Expr, Span>
}

Finally, let's define newState to initialize a State object using the Tokens returned by scanSync:

function newState (tokens: Token[]): State {
  const locationToToken = new Map<number, Token>()

  for (let index = 0; index < tokens.length; index++) {
    const token = tokens[index]
    locationToToken.set(token.start, token)
  }

  return {
    locationToToken,
    exprToSpan: new Map(),
  }
}
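
As a quick illustration of how the location lookup behaves, the sketch below builds the same kind of map from a hard-coded copy of the tokens and queries it with a few offsets. The names scannedTokens and byStart are illustrative stand-ins, not part of the library.

```typescript
// Local stand-in for the Token type; the fields mirror the scanSync output.
interface Token {
  kind: string
  start: number
  end: number
  keyword: string
}

// Hard-coded copy of the tokens for 'SELECT 1 + 2 AS three'.
const scannedTokens: Token[] = [
  { kind: 'SELECT', start: 0, end: 6, keyword: 'RESERVED_KEYWORD' },
  { kind: 'ICONST', start: 7, end: 8, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_43', start: 9, end: 10, keyword: 'NO_KEYWORD' },
  { kind: 'ICONST', start: 11, end: 12, keyword: 'NO_KEYWORD' },
  { kind: 'AS', start: 13, end: 15, keyword: 'RESERVED_KEYWORD' },
  { kind: 'IDENT', start: 16, end: 21, keyword: 'NO_KEYWORD' }
]

// Same idea as locationToToken: key each token by its start offset.
const byStart = new Map<number, Token>()
for (const token of scannedTokens) byStart.set(token.start, token)

console.log(byStart.get(7)?.kind) // → ICONST (the A_Const at location 7)
console.log(byStart.get(9)?.kind) // → ASCII_43 (the A_Expr at location 9)
console.log(byStart.get(8)) // → undefined (offset 8 is whitespace)
```

The last lookup is worth keeping in mind: an offset that doesn't coincide with the start of a token has no entry in the map.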

Visiting and annotating the AST

We want to write a function that recurses through the AST, visiting every expression. Then, for each expression, we want to annotate it with its Span. We can make a few choices here. For example,

  • Do we use the @pg-nano/pg-parser TypeScript definitions to perform strongly typed recursion, or do we recurse through every array and object?
  • Do we annotate expressions by storing Spans directly on the AST, or do we store them separately in a map?

Let's keep it simple for the blog post and just recurse through every array and object. We'll do that by defining a top-level function, getSpan, and two helper functions getArraySpan and getObjectSpan. Additionally, let's not store Spans directly on the AST; instead, let's store them in our State object's exprToSpan map. With that in mind, let's start with our recursive functions' entrypoint: getSpan.

getSpan

function getSpan (state: State, expr: unknown): Span | undefined {
  if (expr == null || typeof expr !== 'object') return undefined
  else if (Array.isArray(expr)) return getArraySpan(state, expr)
  else return getObjectSpan(state, expr as Record<string, unknown>)
}

Notice how if expr is null, undefined, or any other primitive type, then getSpan returns undefined immediately. This is because there are no children to visit and there is no location information to look up a Token from. Otherwise, if expr is an array we call out to getArraySpan, and if expr is an object we call out to getObjectSpan.

getArraySpan

function getArraySpan (state: State, exprs: unknown[]): Span | undefined {
  let span: Span | undefined

  for (const expr of exprs) {
    const childSpan = getSpan(state, expr)
    if (childSpan == null) continue
    span = mergeSpans(span, childSpan)
  }

  return span
}

In getArraySpan, we start by declaring an undefined span. Then, for each array element, we get the element's childSpan by recursively calling getSpan and merge it into the existing span before finally returning a result.

getObjectSpan

The object case is similar to the array case, except instead of iterating over array elements, we iterate over object members. We also perform a few extra steps…

function getObjectSpan (state: State, expr: Expr): Span | undefined {
  let span = state.exprToSpan.get(expr)
  if (span != null) return span

  let isExpression = false
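  if (typeof expr.location === 'number' && expr.location >= 0) {
    isExpression = true
    // Not every location is guaranteed to coincide with the start of a token,
    // so only seed the Span when the lookup succeeds.
    const token = state.locationToToken.get(expr.location)
    if (token != null) span = newSpan(token)
  }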

  for (const key in expr) {
    const childSpan = getSpan(state, expr[key])
    if (childSpan == null) continue
    span = mergeSpans(span, childSpan)
  }

  if (isExpression && span != null) {
    state.exprToSpan.set(expr, span)
  }

  return span
}

First, we check if we've already got a Span for expr in exprToSpan. If so, we return it immediately. Otherwise, we check for a location field to determine if expr is an expression whose Span we should save. If so, we initialize span to its starting Token. Next, we merge in the Spans of expr's members, just like in the array case. Finally, if expr is an expression, we update exprToSpan before returning.

Testing it out

Test the function on the first ResTarget node in parseResult and print the result:

let state = newState(tokens)
let span = getSpan(state, (parseResult.stmts[0].stmt as any).SelectStmt.targetList[0].ResTarget)
console.log(inspect(span, false, null, true))

Doing so should log tokens representing the range 7–12:

{
  left: { kind: 'ICONST', start: 7, end: 8, keyword: 'NO_KEYWORD' },
  right: { kind: 'ICONST', start: 11, end: 12, keyword: 'NO_KEYWORD' }
}
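
If you can't install the native addon, you can still replay the walk against hard-coded copies of the AST and tokens from earlier. The following is a condensed sketch rather than the code above: it tracks plain [start, end] pairs instead of Spans, and the names Range, tokenRanges, resTarget, found, and rangeOf are made up for this example.

```typescript
type Range = [number, number]

// Token ranges keyed by start offset, copied from the scanSync output above.
const tokenRanges = new Map<number, Range>([
  [0, [0, 6]], [7, [7, 8]], [9, [9, 10]],
  [11, [11, 12]], [13, [13, 15]], [16, [16, 21]]
])

// The ResTarget subtree, copied from the parseQuerySync output above.
const resTarget = {
  name: 'three',
  val: {
    A_Expr: {
      kind: 'AEXPR_OP',
      name: [{ String: { sval: '+' } }],
      lexpr: { A_Const: { ival: { ival: 1 }, location: 7 } },
      rexpr: { A_Const: { ival: { ival: 2 }, location: 11 } },
      location: 9
    }
  },
  location: 7
}

// Records "location: start-end" for every node that has a location.
const found: string[] = []

function rangeOf (node: unknown): Range | undefined {
  if (node == null || typeof node !== 'object') return undefined
  const record = node as Record<string, unknown>
  const location = record.location

  // Base case: seed the range from the token that starts at `location`.
  let range = typeof location === 'number' ? tokenRanges.get(location) : undefined

  // Recursive case: widen the range to cover every child's range.
  for (const child of Object.values(record)) {
    const childRange = rangeOf(child)
    if (childRange == null) continue
    range = range == null
      ? childRange
      : [Math.min(range[0], childRange[0]), Math.max(range[1], childRange[1])]
  }

  if (typeof location === 'number' && range != null) {
    found.push(`${location}: ${range[0]}-${range[1]}`)
  }
  return range
}

rangeOf(resTarget)
console.log(found)
// → [ '7: 7-8', '11: 11-12', '9: 7-12', '7: 7-12' ]
```

The entries come out in post-order: the two A_Const leaves first, then the A_Expr, then the ResTarget itself.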

Visualizing Spans

Now that we can calculate Spans for every expression in exprToSpan, we can use that information to highlight or underline expressions in the AST. Let's use box-drawing characters to underline expressions, like we've been doing in the examples above. Start by defining a function, printWithSpan, that takes a string and optional Span, and prints them to the console:

function printWithSpan (input: string, span?: Span): void {
  console.log(input)

  if (span == null) {
    console.log('\n')
    return
  }

  const start = span.left.start
  const end = span.right.end
  if (start === end - 1) console.log('^'.padStart(start + 1) + '\n')
  else console.log('└'.padStart(start + 1, ' ') + '┘'.padStart(end - start - 1, '─') + '\n')
}

Calling printWithSpan prints the input followed by an underline. For example, we can print the Span for the ResTarget node that represents the arithmetic expression:

span = getSpan(state, (parseResult.stmts[0].stmt as any).SelectStmt.targetList[0].ResTarget)
printWithSpan(input, span)

Doing so logs the following to the console:

SELECT 1 + 2 AS three
       └───┘

What next?

If you play around with this approach, you'll quickly notice some improvements we could make to how Spans are calculated. We'll introduce a few such cases and potential solutions, but solving these problems is left as an exercise to the reader!

Aliases

Notice how the Span for the ResTarget node above excludes "AS alias". In fact, the alias is part of the ResTarget node (it's included in its name property), and so it would be better if the Span included this:

SELECT 1 + 2 AS three
       └────────────┘

To fix this, we could update our getObjectSpan function to use the types from @pg-nano/pg-parser, match on ResTarget, extract its name, and consume additional tokens from the end of its Span through "AS alias". Keep in mind that (1) the "AS" token is optional, and (2) we should ignore any comment tokens.
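
One possible shape for the token-consuming part is sketched below. It is not a complete solution. The helpers nextSignificant and extendThroughAlias are hypothetical, the Token and Span types are local stand-ins, and I'm assuming that the scanner reports comments with the kinds SQL_COMMENT and C_COMMENT. It also assumes the caller has already checked that the ResTarget has a name and that the alias follows the expression, which is true for SELECT targets but not for every use of ResTarget.

```typescript
interface Token { kind: string, start: number, end: number, keyword: string }
interface Span { left: Token, right: Token }

// Assumption: comments are reported with these token kinds.
const COMMENT_KINDS = new Set(['SQL_COMMENT', 'C_COMMENT'])

// Hypothetical helper: index of the next non-comment token after `index`, or -1.
function nextSignificant (tokens: Token[], index: number): number {
  for (let i = index + 1; i < tokens.length; i++) {
    const token = tokens[i]
    if (token != null && !COMMENT_KINDS.has(token.kind)) return i
  }
  return -1
}

// Hypothetical helper: extend a Span through an optional AS and the alias.
// The alias token's kind isn't checked, since aliases aren't always IDENTs.
function extendThroughAlias (tokens: Token[], span: Span): Span {
  const rightIndex = tokens.indexOf(span.right)
  if (rightIndex < 0) return span

  let index = nextSignificant(tokens, rightIndex)
  if (index >= 0 && tokens[index]?.kind === 'AS') index = nextSignificant(tokens, index)

  const alias = index >= 0 ? tokens[index] : undefined
  return alias == null ? span : { left: span.left, right: alias }
}

// Hard-coded copy of the tokens for 'SELECT 1 + 2 AS three'.
const aliasTokens: Token[] = [
  { kind: 'SELECT', start: 0, end: 6, keyword: 'RESERVED_KEYWORD' },
  { kind: 'ICONST', start: 7, end: 8, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_43', start: 9, end: 10, keyword: 'NO_KEYWORD' },
  { kind: 'ICONST', start: 11, end: 12, keyword: 'NO_KEYWORD' },
  { kind: 'AS', start: 13, end: 15, keyword: 'RESERVED_KEYWORD' },
  { kind: 'IDENT', start: 16, end: 21, keyword: 'NO_KEYWORD' }
]

// The Span we computed for the ResTarget: 7 through 12.
const aliasSpan: Span = { left: aliasTokens[1]!, right: aliasTokens[3]! }
const withAlias = extendThroughAlias(aliasTokens, aliasSpan)
console.log(withAlias.left.start, withAlias.right.end)
// → 7 21
```

Note that tokens.indexOf relies on the Span holding the very same Token objects as the array, which is the case when Spans are built from locationToToken.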

Parentheses

Notice how the Spans for parenthesized expressions are calculated incorrectly:

input = 'SELECT (((1 + 2)))'
parseResult = parseQuerySync(input)
tokens = scanSync(input)
state = newState(tokens)
span = getSpan(state, (parseResult.stmts[0].stmt as any).SelectStmt.targetList[0].ResTarget)
printWithSpan(input, span)

Running this logs the following to the console:

SELECT (((1 + 2)))
       └──────┘

To fix this, we can update our getObjectSpan function to count unmatched parentheses within an expression's Span, and then nudge the left- and right-most tokens of the Span to include the missing parentheses. Again, keep in mind that we should ignore any comment tokens.
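
Here is a rough sketch of that counting-and-nudging step. The helper balanceParens is hypothetical, and the Token and Span types are local stand-ins. I'm assuming that parentheses are reported as ASCII_40 and ASCII_41 (by analogy with ASCII_43 for +) and that comments use the kinds SQL_COMMENT and C_COMMENT. The token list is written by hand to approximate what scanSync would return for this input.

```typescript
interface Token { kind: string, start: number, end: number, keyword: string }
interface Span { left: Token, right: Token }

// Assumption: comments are reported with these token kinds.
const COMMENT_KINDS = new Set(['SQL_COMMENT', 'C_COMMENT'])

// Hypothetical helper: balance unmatched parentheses inside a Span by
// consuming parenthesis tokens that sit just outside of it.
function balanceParens (tokens: Token[], span: Span): Span {
  const leftIndex = tokens.indexOf(span.left)
  const rightIndex = tokens.indexOf(span.right)
  if (leftIndex < 0 || rightIndex < 0) return span

  // Count "(" without a matching ")" inside the Span, and vice versa.
  let unclosed = 0
  let unopened = 0
  for (let i = leftIndex; i <= rightIndex; i++) {
    const kind = tokens[i]?.kind
    if (kind === 'ASCII_40') unclosed++
    else if (kind === 'ASCII_41') {
      if (unclosed > 0) unclosed--
      else unopened++
    }
  }

  let { left, right } = span

  // Nudge the right edge outward over the missing ")" tokens.
  for (let i = rightIndex + 1; unclosed > 0 && i < tokens.length; i++) {
    const token = tokens[i]
    if (token == null || COMMENT_KINDS.has(token.kind)) continue
    if (token.kind !== 'ASCII_41') break
    right = token
    unclosed--
  }

  // Nudge the left edge outward over the missing "(" tokens.
  for (let i = leftIndex - 1; unopened > 0 && i >= 0; i--) {
    const token = tokens[i]
    if (token == null || COMMENT_KINDS.has(token.kind)) continue
    if (token.kind !== 'ASCII_40') break
    left = token
    unopened--
  }

  return { left, right }
}

// Hand-written tokens for 'SELECT (((1 + 2)))'.
const t = (kind: string, start: number, end: number): Token =>
  ({ kind, start, end, keyword: kind === 'SELECT' ? 'RESERVED_KEYWORD' : 'NO_KEYWORD' })
const parenTokens: Token[] = [
  t('SELECT', 0, 6), t('ASCII_40', 7, 8), t('ASCII_40', 8, 9), t('ASCII_40', 9, 10),
  t('ICONST', 10, 11), t('ASCII_43', 12, 13), t('ICONST', 14, 15),
  t('ASCII_41', 15, 16), t('ASCII_41', 16, 17), t('ASCII_41', 17, 18)
]

// The Span we computed for the ResTarget: 7 through 15.
const balanced = balanceParens(parenTokens, { left: parenTokens[1]!, right: parenTokens[6]! })
console.log(balanced.left.start, balanced.right.end)
// → 7 18
```

This only repairs Spans that already contain an unmatched parenthesis. The inner A_Expr's Span, which covers just 1 + 2, is balanced as-is, so deciding whether it should also grow to include the surrounding parentheses is a separate design choice.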

SELECT COUNT(*) FROM t

Take a look at how PostgreSQL parses the following SELECT statement:

input = 'SELECT COUNT(*) FROM t'
parseResult = parseQuerySync(input)
console.log(inspect(parseResult, false, null, true))

The expression COUNT(*) is represented as a FuncCall with agg_star set to true:

{
  "FuncCall": {
    "funcname": [
      {
        "String": {
          "sval": "count"
        }
      }
    ],
    "agg_star": true,
    "funcformat": "COERCE_EXPLICIT_CALL",
    "location": 7
  }
}

If we print the Span for this FuncCall, we see that it excludes the parentheses and asterisk:

tokens = scanSync(input)
state = newState(tokens)
span = getSpan(state, (parseResult.stmts[0].stmt as any).SelectStmt.targetList[0].ResTarget)
printWithSpan(input, span)

Running this logs the following to the console:

SELECT COUNT(*) FROM t
       └───┘

Ideally, we'd like to print the following:

SELECT COUNT(*) FROM t
       └──────┘

To fix this, we can update our getObjectSpan function to use the types from @pg-nano/pg-parser, match on FuncCall, and extract its agg_star. Then, if agg_star is true, we can consume a left parenthesis token, an asterisk token, and a right parenthesis token. Again, keep in mind that we must skip over comment tokens.
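
The token-consuming half of that fix might look something like the sketch below. The helper consumeSequence is hypothetical, and the Token and Span types are local stand-ins. I'm assuming the three tokens are reported as ASCII_40, ASCII_42, and ASCII_41, and that comments use the kinds SQL_COMMENT and C_COMMENT. The token list is hand-written to approximate scanSync's output.

```typescript
interface Token { kind: string, start: number, end: number, keyword: string }
interface Span { left: Token, right: Token }

// Assumption: comments are reported with these token kinds.
const COMMENT_KINDS = new Set(['SQL_COMMENT', 'C_COMMENT'])

// Hypothetical helper: if the tokens after the Span match `kinds` in order
// (ignoring comments), extend the Span over them; otherwise leave it alone.
function consumeSequence (tokens: Token[], span: Span, kinds: string[]): Span {
  let index = tokens.indexOf(span.right)
  if (index < 0) return span

  let right = span.right
  for (const kind of kinds) {
    // Advance to the next non-comment token.
    do { index++ } while (COMMENT_KINDS.has(tokens[index]?.kind ?? ''))
    const token = tokens[index]
    if (token == null || token.kind !== kind) return span
    right = token
  }

  return { left: span.left, right }
}

// Hand-written tokens for 'SELECT COUNT(*) FROM t'.
const countTokens: Token[] = [
  { kind: 'SELECT', start: 0, end: 6, keyword: 'RESERVED_KEYWORD' },
  { kind: 'IDENT', start: 7, end: 12, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_40', start: 12, end: 13, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_42', start: 13, end: 14, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_41', start: 14, end: 15, keyword: 'NO_KEYWORD' },
  { kind: 'FROM', start: 16, end: 20, keyword: 'RESERVED_KEYWORD' },
  { kind: 'IDENT', start: 21, end: 22, keyword: 'NO_KEYWORD' }
]

// The Span we computed for the FuncCall covers only the COUNT token.
const funcSpan: Span = { left: countTokens[1]!, right: countTokens[1]! }
const starSpan = consumeSequence(countTokens, funcSpan, ['ASCII_40', 'ASCII_42', 'ASCII_41'])
console.log(starSpan.left.start, starSpan.right.end)
// → 7 15
```

Returning the original Span on any mismatch keeps the helper safe to call speculatively, although a real implementation would only call it when agg_star is true.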

The SELECT token

Perhaps most basic of all, if we print the Span for the entire parseResult, it excludes the SELECT token!

span = getSpan(state, parseResult)
printWithSpan(input, span)

Running this logs the following to the console:

SELECT COUNT(*) FROM t
       └─────────────┘

Ideally, we'd like to print the following:

SELECT COUNT(*) FROM t
└────────────────────┘

To fix this, we can update our getObjectSpan function to use the types from @pg-nano/pg-parser, match on SelectStmt, and consume its initial SELECT token, keeping in mind that we must skip over comment tokens.
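
As a starting point, the leftward step could be sketched as follows. The helper extendToSelect is hypothetical, the Token and Span types are local stand-ins, and I'm assuming comments use the kinds SQL_COMMENT and C_COMMENT. This version only handles the simple case where SELECT is the nearest preceding non-comment token, so a statement like SELECT DISTINCT would need more work.

```typescript
interface Token { kind: string, start: number, end: number, keyword: string }
interface Span { left: Token, right: Token }

// Assumption: comments are reported with these token kinds.
const COMMENT_KINDS = new Set(['SQL_COMMENT', 'C_COMMENT'])

// Hypothetical helper: if the nearest non-comment token before the Span is
// SELECT, extend the Span leftward to include it.
function extendToSelect (tokens: Token[], span: Span): Span {
  let index = tokens.indexOf(span.left)
  if (index < 0) return span

  // Step backward to the previous non-comment token.
  do { index-- } while (COMMENT_KINDS.has(tokens[index]?.kind ?? ''))

  const token = tokens[index]
  if (token != null && token.kind === 'SELECT') return { left: token, right: span.right }
  return span
}

// Hand-written tokens for 'SELECT COUNT(*) FROM t'.
const selectTokens: Token[] = [
  { kind: 'SELECT', start: 0, end: 6, keyword: 'RESERVED_KEYWORD' },
  { kind: 'IDENT', start: 7, end: 12, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_40', start: 12, end: 13, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_42', start: 13, end: 14, keyword: 'NO_KEYWORD' },
  { kind: 'ASCII_41', start: 14, end: 15, keyword: 'NO_KEYWORD' },
  { kind: 'FROM', start: 16, end: 20, keyword: 'RESERVED_KEYWORD' },
  { kind: 'IDENT', start: 21, end: 22, keyword: 'NO_KEYWORD' }
]

// The Span we computed for the whole parseResult: 7 through 22.
const stmtSpan: Span = { left: selectTokens[1]!, right: selectTokens[6]! }
const fullSpan = extendToSelect(selectTokens, stmtSpan)
console.log(fullSpan.left.start, fullSpan.right.end)
// → 0 22
```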

Try it yourself!

This blog post was written as literate TypeScript embedded in Markdown. You can extract and run the TypeScript code from the blog post by using codedown and piping the output to Node.js with the --experimental-strip-types flag:

cd $(mktemp -d)
npm init -y
npm install @pg-nano/pg-parser @markandrus/pg-parser
curl https://gist.githubusercontent.com/markandrus/859f5aa97c088c202e42142d9e876b01/raw/7b7fc019e59a50e63952bde9bfddc147cbec5910/how_to_annotate_postgresql_asts_with_location_information.md >post.md
npx codedown ts <post.md >post.ts
node --experimental-strip-types post.ts