Extracting the Module and Function Names from Python ASTs

python
Author

Arumoy Shome

Published

March 23, 2024

Abstract
How to extract the module and function name from Python Abstract Syntax Trees.

Preliminaries: Python ast module

Python has a built-in ast module, and its documentation describes in detail the various nodes used to represent different elements of Python source code.

We can use the ast.parse function to create an AST from Python source code. Here is an example of how a function call is represented in an AST.

import ast
ast.parse("foo(x, y)")
<ast.Module at 0x1063c3950>

The ast.parse function returns an ast.Module object, which by itself is not very helpful. To view the internal structure of the tree, we can use the ast.dump function.

ast.dump(ast.parse("foo(x, y)"))
"Module(body=[Expr(value=Call(func=Name(id='foo', ctx=Load()), args=[Name(id='x', ctx=Load()), Name(id='y', ctx=Load())]))])"

We can pass the indent argument (available since Python 3.9) to ast.dump and print the result to make the output more readable.

Listing 1: Base case
print(ast.dump(ast.parse("foo(x, y)"), indent=4))
Module(
    body=[
        Expr(
            value=Call(
                func=Name(id='foo', ctx=Load()),
                args=[
                    Name(id='x', ctx=Load()),
                    Name(id='y', ctx=Load())]))])

Note that the function call is represented by an ast.Call node, which has func and args attributes. The function name (in our case, foo) is represented by an ast.Name node, with the actual name under the id attribute.
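The same attributes can be read directly from the tree instead of from the dump. The following is a minimal sketch; the variable names are chosen here for illustration and are not part of the ast API.

```python
import ast

# Parse the same call as in Listing 1 and pull out the Call node.
tree = ast.parse("foo(x, y)")
call = tree.body[0].value  # Module.body[0] is an Expr; its value is the Call

function_name = call.func.id                    # Name.id holds the function name
argument_names = [arg.id for arg in call.args]  # each argument is also a Name

print(type(call.func).__name__)  # Name
print(function_name)             # foo
print(argument_names)            # ['x', 'y']
```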

And here is the AST when we call a function defined in a different module.

Listing 2: Single nested function call
print(ast.dump(ast.parse("bar.foo(x, y)"), indent=4))
Module(
    body=[
        Expr(
            value=Call(
                func=Attribute(
                    value=Name(id='bar', ctx=Load()),
                    attr='foo',
                    ctx=Load()),
                args=[
                    Name(id='x', ctx=Load()),
                    Name(id='y', ctx=Load())]))])

Things are a bit different now. Call.func is no longer an ast.Name node but an ast.Attribute node. Attribute.value is an ast.Name node whose id attribute holds the name of the module (in our case, bar), while the name of the function is on the attr attribute of the Attribute node itself.
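As a quick check of this layout, the module and function names can be read off the Call node by hand. This is only a sketch for the single-dot case of Listing 2; the variable names are illustrative.

```python
import ast

call = ast.parse("bar.foo(x, y)").body[0].value
func = call.func  # an ast.Attribute for a dotted call

module_name = func.value.id  # the Name under Attribute.value
function_name = func.attr    # the attr of the Attribute itself

print(type(func).__name__)         # Attribute
print(module_name, function_name)  # bar foo
```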

Let's examine something a bit more complicated: what if the function is in a submodule?

print(ast.dump(ast.parse("baz.bar.foo(x, y)"), indent=4))
Module(
    body=[
        Expr(
            value=Call(
                func=Attribute(
                    value=Attribute(
                        value=Name(id='baz', ctx=Load()),
                        attr='bar',
                        ctx=Load()),
                    attr='foo',
                    ctx=Load()),
                args=[
                    Name(id='x', ctx=Load()),
                    Name(id='y', ctx=Load())]))])

And even more nested?

print(ast.dump(ast.parse("quack.baz.bar.foo(x, y)"), indent=4))
Module(
    body=[
        Expr(
            value=Call(
                func=Attribute(
                    value=Attribute(
                        value=Attribute(
                            value=Name(id='quack', ctx=Load()),
                            attr='baz',
                            ctx=Load()),
                        attr='bar',
                        ctx=Load()),
                    attr='foo',
                    ctx=Load()),
                args=[
                    Name(id='x', ctx=Load()),
                    Name(id='y', ctx=Load())]))])

It seems that nested function calls are represented using nested ast.Attribute nodes. The top-level module name is always an ast.Name node under the value of the deepest ast.Attribute node, and the function name is always the attr of the first (outermost) ast.Attribute node.
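One way to confirm this pattern is to follow the value attribute down the chain until we reach a Name node. The helper below, dotted_parts, is a hypothetical name used only for this sketch, and it assumes the chain ends in a plain Name (it silently ignores anything else).

```python
import ast


def dotted_parts(func: ast.expr) -> list[str]:
    """Return the names in a dotted callee, top-level module first."""
    parts = []
    node = func
    # Each Attribute contributes its attr; its value is the next level down.
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    # The chain bottoms out at a Name holding the top-level module.
    if isinstance(node, ast.Name):
        parts.append(node.id)
    return parts[::-1]


call = ast.parse("quack.baz.bar.foo(x, y)").body[0].value
parts = dotted_parts(call.func)
print(parts)  # ['quack', 'baz', 'bar', 'foo']
```

The first element is the top-level module and the last element is the function name, matching the observation above.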

Extracting the Function Names

Let's start with the simplest case, where we are only interested in the function names (i.e. we only want to extract foo from all scenarios presented above). There are two cases to consider here:

  1. If it's a direct function call, then Call.func will be a Name node.
  2. If it's a nested function call, then Call.func will contain nested Attribute nodes. The name of the function will be the attr of the first (outermost) Attribute.

We can do this using the ast.NodeVisitor class. Let's create a FunctionNameCollector class which inherits from ast.NodeVisitor. In the class, we define visit_Name and visit_Attribute methods, which are called every time we visit a Name or Attribute node respectively (more on this later).

class FunctionNameCollector(ast.NodeVisitor):
    def __init__(self):
        self.names = []

    def visit_Name(self, node: ast.Name) -> None:
        self.names.append(node.id)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        self.names.append(node.attr)


tests = ["foo(x, y)", "bar.foo(x, y)", "baz.bar.foo(x, y)", "quack.baz.bar.foo(x, y)"]

trees = [ast.parse(test) for test in tests]
call_nodes = [
    node for tree in trees for node in ast.walk(tree) if isinstance(node, ast.Call)
]
collector = FunctionNameCollector()
for node in call_nodes:
    collector.visit(node.func)

collector.names
['foo', 'foo', 'foo', 'foo']

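I collect all the ast.Call nodes in our test cases using the ast.walk function, which returns a generator that yields the given node and all of its descendants. Then I call the visit method provided by ast.NodeVisitor on each Call.func node. The visit method dispatches to visit_Name or visit_Attribute based on the type of the node, and since neither method goes on to visit its children, only the Call.func node itself is inspected. This is why we get just foo in every case.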
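To make the dispatch behaviour concrete, here is a small sketch. DispatchLogger is a hypothetical class written for this illustration; it records which handler runs when we visit the Call.func node of a dotted call.

```python
import ast

tree = ast.parse("bar.foo(x, y)")

# ast.walk yields the starting node and every descendant (order unspecified).
node_types = {type(node).__name__ for node in ast.walk(tree)}


class DispatchLogger(ast.NodeVisitor):
    """Hypothetical visitor that records which handler was called."""

    def __init__(self):
        self.log = []

    def visit_Attribute(self, node: ast.Attribute) -> None:
        # generic_visit is not called, so the Name under node.value is skipped.
        self.log.append(f"visit_Attribute: {node.attr}")

    def visit_Name(self, node: ast.Name) -> None:
        self.log.append(f"visit_Name: {node.id}")


call = tree.body[0].value
logger = DispatchLogger()
logger.visit(call.func)
print(logger.log)  # ['visit_Attribute: foo']
```

Only visit_Attribute runs here: the Name node for bar sits under the Attribute, and a handler has to call visit or generic_visit itself for the traversal to continue.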

Collecting Both Module and Function Names

Here is where things get a bit more interesting. Here are the cases to consider:

  1. The base case is that we have a direct function call, in which case we need to extract the function name from Name.id under Call.func (same as before).
  2. However, if it is a nested function call, then:
    1. The function name will be the attr of the first (outermost) Attribute and
    2. The module name will be the id of the Name under the last (innermost) Attribute.value.

The visit_Name method still handles the base case, but it now also has to recognise when it was reached through an Attribute. We also need to modify the visit_Attribute method so that it traverses the nodes under Attribute.value until we hit the base case. Here is the modified code.

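class NameCollector(ast.NodeVisitor):
    def __init__(self):
        self.names = []  # collected (module, function) tuples
        self.stack = []  # Attribute nodes currently being traversed

    def visit_Name(self, node: ast.Name) -> None:
        if self.stack:
            # End of a dotted chain: stack[0] is the outermost Attribute.
            self.names.append((node.id, self.stack[0].attr))
        else:
            # Direct call: there is no module.
            self.names.append((None, node.id))

    def visit_Attribute(self, node: ast.Attribute) -> None:
        self.stack.append(node)
        self.visit(node.value)
        # Pop so that Attribute nodes from one call do not leak into the next.
        self.stack.pop()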

collector = NameCollector()
for node in call_nodes:
    collector.visit(node.func)

collector.names
[(None, 'foo'), ('bar', 'foo'), ('baz', 'foo'), ('quack', 'foo')]

The code is similar to FunctionNameCollector defined above, with a few key changes. collector.names is now a list of (module, function) tuples.

I use a stack to keep track of the Attribute nodes we visit. Whenever we visit an Attribute node, I append it to the stack and then call the NodeVisitor.visit method on the node under its value attribute.

When we visit a Name node, it can either be because we are at the base case (a direct function call) or because we have reached the last Attribute.value node. If the stack is not empty, then it's the latter, which means the Name node is the module name and the function name is the attr of the first Attribute in the stack. Otherwise, it's a direct function call, so the Name node is the function name. Since we don't have a module in this case, we store None in the tuple.
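Both branches can be checked on a short multi-line snippet where a direct call follows a dotted one. The sketch below is a standalone variant of the idea: StackCollector, collect_calls and the sample source are names made up for this illustration, and it assumes that each Attribute is popped off the stack once its value has been visited, so that the stack is empty again before the next call is processed.

```python
import ast


class StackCollector(ast.NodeVisitor):
    """Standalone sketch of the stack-based collector described above."""

    def __init__(self):
        self.names = []
        self.stack = []

    def visit_Name(self, node: ast.Name) -> None:
        if self.stack:
            # Reached through Attribute nodes: the Name is the module.
            self.names.append((node.id, self.stack[0].attr))
        else:
            # Direct call: the Name is the function itself.
            self.names.append((None, node.id))

    def visit_Attribute(self, node: ast.Attribute) -> None:
        self.stack.append(node)
        self.visit(node.value)
        self.stack.pop()  # leave the stack empty for the next call


def collect_calls(source: str) -> list:
    """Hypothetical helper: collect (module, function) for each call."""
    collector = StackCollector()
    for statement in ast.parse(source).body:  # statements in source order
        for node in ast.walk(statement):
            if isinstance(node, ast.Call):
                collector.visit(node.func)
    return collector.names


source = "np.linalg.norm(v)\nprint(v)\nos.getcwd()\n"
result = collect_calls(source)
print(result)  # [('np', 'norm'), (None, 'print'), ('os', 'getcwd')]
```

Note that, as in the article, only the top-level module is reported: np.linalg.norm yields ('np', 'norm') rather than the full np.linalg path.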
