Code Polish [01] — One hot encoding

Mycchaka Kleinbort
5 min read · Nov 17, 2022


Fixing a watch — image from https://bit.ly/3V37CS9

I couldn't say if this is a good use of time or not, but I absolutely love making little snippets of code as fast and elegant as possible.

To that end, this is the first in a series of brain teasers where I take “good” code and make it as fast and beautiful as I know how.

Q: How do you one-hot-encode a column that is a list of categories?

This is a recent question from a friend. The setup is as follows:

You have a dataframe df with a column segments. Each entry is a string that represents the groups the user belongs to.

For example, '12,14,57' means the user belongs to segments '12', '14' and '57'.

If you want to follow along, this code snippet will generate some dummy data:

import pandas as pd 
import numpy as np

segments = [str(i) for i in range(1_000)]

nums = np.random.choice(segments, (100_000,10))

df = pd.DataFrame({'segments': [','.join(n) for n in nums]})

This code generates a DataFrame with 100k rows.

[Image: a preview of the dummy data]

The goal is to cast this to a 100,000 x 1,000 dataframe that one-hot-encodes the segment membership as expressed in the segments column.
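
To make the target concrete, here is a tiny made-up example (two rows, three segments; the values are illustrative only):

import pandas as pd

toy = pd.DataFrame({'segments': ['12,14', '14,57']})

# Desired output: one column per segment, 1 if the user is a member, 0 otherwise
#    12  14  57
# 0   1   1   0
# 1   0   1   1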

V1. The naive solution [Speed x1 base case, 68s]

The most straightforward way to solve the problem is with a for loop:

segment_memberships = df['segments'].str.split(',')

for segment in segments:
    # regex=False makes this a plain `in` check against each row's list of segments
    df[segment] = segment_memberships.str.contains(segment, regex=False)

This code works, but at 68 seconds it is a little slow. We can do much better.

V2. Slight improvement using sets [Speed x1.08 base, 64s]

One obvious observation is that we are doing ~100,000 x 1,000 membership checks. For that many lookups, a list can be sub-optimal, and sets come to mind.

segment_memberships = df['segments'].str.split(',').apply(set)

for segment in segments:
    # Same as before, but the `in` check now hits a set rather than a list
    df[segment] = segment_memberships.str.contains(segment, regex=False)

The difference is nearly trivial in this case, mainly because each list is only 10 elements long, but it’s a helpful building block for the next solution.
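
If you want to convince yourself of that, here is a quick standalone sketch in plain Python (timings are illustrative and will vary by machine). The gap between list and set membership is tiny at 10 elements and grows with collection size, since a list check is O(n) while a set check is roughly O(1):

from timeit import timeit

items_list = [str(i) for i in range(10)]
items_set = set(items_list)

print(timeit(lambda: '7' in items_list))  # linear scan
print(timeit(lambda: '7' in items_set))   # hash lookup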

V3. Using a lambda [Speed x5.3 base, 12.8s]

This is a small change: just replace .str.contains with a lambda and the in operator.

segment_memberships = df['segments'].str.split(',').apply(set)

for segment in segments:
    df[segment] = segment_memberships.apply(lambda x: segment in x)

However, we can do better.

V4. Using the vectorized ≤ operator [Speed x8.25 base, 8.3s]

Ok, now for the black magic. Python and pandas have a number of “fast” code paths that give you near-magical performance improvements. In this case we can drop the Python-level lambda and lean on the ≤/≥ set-comparison operators, which pandas applies element-wise across the whole column.

Note that, for a given row's set of memberships,

segment in memberships

is equivalent to

{segment} <= memberships

or, with the operands flipped, memberships >= {segment}: a singleton-subset check instead of a membership check.
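
As a quick sanity check in plain Python (no pandas involved):

memberships = {'12', '14', '57'}

# The membership test and the singleton-subset test agree:
assert ('57' in memberships) == ({'57'} <= memberships) == (memberships >= {'57'})

# Careful: set('57') is {'5', '7'}, not {'57'}, so use a set literal for the singleton.
assert set('57') == {'5', '7'}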

This lets us write:

segment_memberships = df['segments'].str.split(',').apply(set)

for segment in segments:
    df[segment] = segment_memberships >= {segment}

For a total speedup of 8x the base case (8.3 seconds).

This roughly bottoms out what we can do with this for-loop approach. Time to look for other efficiencies.

V5. The way pandas would want us to do it [Speed 18.6x base, 3.7s]

The developers of Pandas have a solution for this problem, the .explode() method.

This expands the list of segments into distinct rows, which we can then pivot to get our one-hot-encoded dataframe.

df_ans = (df['segments']
          .str.split(',')
          .explode()
          .reset_index()
          .assign(__one__=1)
          .pivot_table(index='index',
                       columns='segments',
                       values='__one__',
                       fill_value=0)
          )
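
To see what the intermediate .explode() step produces, here is a minimal sketch on a made-up two-row frame (the values are illustrative only):

import pandas as pd

toy = pd.DataFrame({'segments': ['12,14', '14,57']})

print(toy['segments'].str.split(',').explode())
# 0    12
# 0    14
# 1    14
# 1    57
# Name: segments, dtype: object

Each (user, segment) pair becomes its own row, with the original row label repeated; the pivot then folds these back into one row per user.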

Sadly, this is more or less where the pandas speedups end. And to be fair, if you want to stop here that’s totally fine.

However, I’ve recently come across Polars, a very cool, very fast dataframe library.

There is some overhead in switching from pandas to Polars and back, but overall…

V6. Refactored using Polars [Speed 154x base, 440ms]

Now, I really do like Polars (named after polar bears). I don’t think teams should throw out all their pandas expertise and refactor it all into Polars… but it is tempting.

The logic in Polars is the same as in V5; it just has slightly different syntax.

import polars as pl

df_ans = (df
          .pipe(pl.from_pandas)                       # From pandas into Polars
          .select(pl.col('segments').str.split(','))
          .with_row_count(name='index')               # Polars doesn't have an index, so create one from the row number
          .with_column(pl.lit(1).alias('__one__'))    # A column of 1s for the pivot
          .explode('segments')                        # Same idea as pandas's .explode()
          .pivot(index='index',
                 columns='segments',
                 values='__one__',
                 aggregate_fn='sum')                  # A pivot, like pandas's .pivot_table
          .fill_null(0)                               # Fill the missing values with 0
          .to_pandas()                                # Back from Polars to pandas
          )

On my computer, the code above ran in 443ms. That is 154x faster than the original solution (V1) and 19x faster than the best pandas-based solution (V5).

Moreover, note that the code above includes a conversion to and from pandas. Without it the code runs in ~180ms, for a final vanity speed-up number of 380x!
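
For reference, here is a minimal sketch of that pandas-free variant, reusing the nums array from the dummy-data snippet above and the same Polars API as the snippet in V6 (newer Polars releases rename some of these calls, e.g. with_columns and with_row_index, so adjust to your installed version):

import polars as pl

# Build the dummy data directly in Polars, skipping pandas entirely
df_pl = pl.DataFrame({'segments': [','.join(n) for n in nums]})

df_ans_pl = (df_pl
             .select(pl.col('segments').str.split(','))
             .with_row_count(name='index')
             .with_column(pl.lit(1).alias('__one__'))
             .explode('segments')
             .pivot(index='index',
                    columns='segments',
                    values='__one__',
                    aggregate_fn='sum')
             .fill_null(0)
             )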

So, where did we get to?

This article covered a few tips and tricks of dataframe manipulations.

We saw how a simple refactoring of the naive solution can give an 8x speed up by leveraging vectorization of the ≤ operator.

We saw how good knowledge of the Pandas API can yield a further 2x speedup.

And we saw how Polars (even with the conversion from pandas and back) can deliver ridiculous speedups.

And does this matter? Yes!

The sample code and numbers here were with a 100,000 row dataframe. At this size, all the solutions are roughly equally valid. But real data can be much, much bigger.

In my friend's case, we were looking at a ~50M-row dataframe that was part of a daily batch job.

I'm a bit too lazy to run this at that scale, but if we assume linear scaling:

  • V1: ~10h  => ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
  • V2: ~9h   => ||||||||||||||||||||||||||||||||||||||||||||||||||||||
  • V3: ~2h   => ||||||||||||
  • V4: ~1.3h => |||||||
  • V5: ~50m  => ||||
  • V6: ~4m   => |

The benefits get bigger and bigger the more you think about them:

  • A 10h job requires monitoring, maybe check-pointing, alerting, etc…
  • If it fails and needs to be restarted the data might be late to downstream processes (in addition to the expected 10h delay!)
  • If the data volumes grow, you will soon be pushed into a multi-machine setup: Spark, Dask, or some other big-data solution.

However, at ~4 minutes, you can pretty much re-run the job at minimal cost. And that sparks joy!

A little up-front craftsmanship can save a lot of time down the road.

If you have a code challenge, email me at mkleinbort@gmail.com
