How to Flatten a 2D List in Python


Suppose we have a list of lists, or 2D list, in Python.

lst_of_lsts = [[1, 2, 3], [4, 5, 6], [7, 8]]

We want to flatten this list, or convert it into a 1D list.

flattened_lst = [1, 2, 3, 4, 5, 6, 7, 8]

Using a for Loop

We can use two nested for loops to build the flattened list.

flattened_lst = []
for sublist in lst_of_lsts:
    for elem in sublist:
        flattened_lst.append(elem)

This is a very intuitive approach that gives us the correct flattened list.
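Here is a minimal variant, assuming each sublist is itself a list (or any iterable). The inner loop can be replaced by list.extend(), which appends every element of its argument in one call. This is a small simplification of the nested loops, not a different technique.

```python
lst_of_lsts = [[1, 2, 3], [4, 5, 6], [7, 8]]

flattened_lst = []
for sublist in lst_of_lsts:
    # extend() appends each element of sublist, replacing the inner for loop
    flattened_lst.extend(sublist)

print(flattened_lst)  # [1, 2, 3, 4, 5, 6, 7, 8]
```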

Using List Comprehension

We can simplify the syntax a bit using a list comprehension.

flattened_lst = [elem for sublist in lst_of_lsts for elem in sublist]

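This is equivalent to the for loop method above: the for clauses appear in the same order as the nested loops, with the outer loop first and the inner loop second.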
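The order of the for clauses matters. In this short sketch, writing the inner clause first makes the comprehension use sublist before anything has bound it, which raises NameError (assuming no variable named sublist already exists in the surrounding scope).

```python
lst_of_lsts = [[1, 2, 3], [4, 5, 6], [7, 8]]

# Correct: outer loop first, inner loop second, as in the nested for loops
flattened_lst = [elem for sublist in lst_of_lsts for elem in sublist]

# Reversed: the first for clause refers to sublist before any loop defines it
reversed_order_failed = False
try:
    wrong = [elem for elem in sublist for sublist in lst_of_lsts]
except NameError:
    reversed_order_failed = True

print(flattened_lst)          # [1, 2, 3, 4, 5, 6, 7, 8]
print(reversed_order_failed)  # True
```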

Using reduce()

We can also use functools.reduce() to achieve the same output.

reduce(func, iter) takes two parameters: a function func and an iterable iter. (It also accepts an optional third argument, an initial value.)

func(a, b) takes two arguments and combines them into a single result. In our case, func is operator.concat, which concatenates two sequences. reduce() applies it cumulatively from left to right, following the reducer/accumulator pattern: the result of each call becomes the first argument of the next call.

iter is any iterable object (list, tuple, dictionary, etc.). Here it is our list of lists.
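To make the accumulator pattern concrete, here is a simplified sketch of what reduce() does when called without an initial value. reduce_sketch is a hypothetical helper written only for illustration; it is not part of functools, and unlike the real function it does not handle an empty iterable.

```python
import operator


def reduce_sketch(func, iterable):
    """Hypothetical, simplified stand-in for functools.reduce (no initial value)."""
    it = iter(iterable)
    acc = next(it)  # the first element seeds the accumulator
    for item in it:
        acc = func(acc, item)  # combine the running result with the next element
    return acc


# Step by step: concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6],
# then concat([1, 2, 3, 4, 5, 6], [7, 8]) -> [1, 2, 3, 4, 5, 6, 7, 8]
flattened_lst = reduce_sketch(operator.concat, [[1, 2, 3], [4, 5, 6], [7, 8]])
```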

import functools
import operator
flattened_lst = functools.reduce(operator.concat, lst_of_lsts)
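The optional third argument to reduce() sets the accumulator's starting value. A sketch, assuming we want an empty list of lists to flatten to an empty list: passing [] as the initial value makes reduce() return [] for empty input instead of raising TypeError.

```python
import functools
import operator

lst_of_lsts = [[1, 2, 3], [4, 5, 6], [7, 8]]

# The initial value [] is concatenated with each sublist in turn
flattened_lst = functools.reduce(operator.concat, lst_of_lsts, [])

# With an initial value, an empty input simply returns that value
empty_flattened = functools.reduce(operator.concat, [], [])

# Without one, reduce() has nothing to start from and raises TypeError
try:
    functools.reduce(operator.concat, [])
    raised = False
except TypeError:
    raised = True
```

Keep in mind that operator.concat builds a brand-new list at every step, so the amount of copying grows roughly quadratically with the number of sublists. For long inputs, the loop, comprehension, or itertools methods avoid that cost.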

Using chain.from_iterable()

We can also use the itertools.chain.from_iterable() function.

import itertools
flattened_lst = list(itertools.chain.from_iterable(lst_of_lsts))

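The from_iterable() function essentially runs a nested for loop, like our first method. It behaves roughly like this generator:

# Simplified sketch, roughly equivalent to itertools.chain.from_iterable
def from_iterable(iterables):
    for it in iterables:       # outer loop: each sublist
        for element in it:     # inner loop: each element of that sublist
            yield element      # produce values lazily, one at a time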

Rather than a list, chain.from_iterable() returns an itertools.chain object, a lazy iterator that we can use to step through every value in our list of lists.

To get the flattened list, we pass this object to the list() constructor.
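Because the chain object is a lazy iterator, we can consume it piece by piece without building a list first. In this sketch, itertools.islice() takes the first three values and list() then collects whatever remains. Once exhausted, the iterator yields nothing more, so a second pass over the same chain object comes back empty.

```python
import itertools

lst_of_lsts = [[1, 2, 3], [4, 5, 6], [7, 8]]

chained = itertools.chain.from_iterable(lst_of_lsts)

# islice() pulls only the first three values from the iterator
first_three = list(itertools.islice(chained, 3))

# list() consumes everything that is left
rest = list(chained)

# The iterator is now exhausted, so another pass is empty
leftover = list(chained)

print(first_three)  # [1, 2, 3]
print(rest)         # [4, 5, 6, 7, 8]
print(leftover)     # []
```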

Using chain() and the Star Operator *

The single star operator * unpacks a sequence or other iterable into positional arguments. For example:

def add(a, b):
    return a + b

nums = (1, 2)
res = add(*nums)  # same as add(1, 2), so res == 3

With this in mind, we can unpack our list of lists directly into the chain() function, from the same itertools module used in the previous method. Each sublist becomes a separate positional argument.

import itertools
flattened_lst = list(itertools.chain(*lst_of_lsts))
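Unpacking turns chain(*lst_of_lsts) into chain([1, 2, 3], [4, 5, 6], [7, 8]), one argument per sublist, as this sketch checks. One difference from chain.from_iterable() is worth knowing: * must consume the entire outer iterable to build the argument tuple before chain() is even called, so this form cannot stay lazy over the outer level (for example, over a generator that produces sublists).

```python
import itertools

lst_of_lsts = [[1, 2, 3], [4, 5, 6], [7, 8]]

# The star operator expands the outer list into separate arguments
unpacked = list(itertools.chain(*lst_of_lsts))
explicit = list(itertools.chain([1, 2, 3], [4, 5, 6], [7, 8]))

# from_iterable() receives the outer list as a single argument instead
from_iter = list(itertools.chain.from_iterable(lst_of_lsts))

print(unpacked == explicit == from_iter)  # True
```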