Optimizing Memory Usage of Scikit-Learn Models Using Succinct Tries

We use the scikit-learn library for various machine-learning tasks at Scrapinghub. For example, for text classification we’d typically build a statistical model using sklearn’s Pipeline, FeatureUnion, some classifier (e.g. LinearSVC) + feature extraction and preprocessing classes. The model is usually trained on a developer’s machine, then serialized (using pickle/joblib) and uploaded to a server where the classification takes place.

Sometimes there can be too little available memory on the server for the classifier. One way to address this is to change the model: use simpler features, do feature selection, change the classifier to a less memory-intensive one, use simpler preprocessing steps, etc. It usually means trading accuracy for better memory usage.

For text, it is often CountVectorizer or TfidfVectorizer that consumes the most memory. For the last few months we have been using a trick to make them much more memory efficient in production (50x+) without changing anything from a statistical point of view; that trick is what this article is about.

Let’s start with the basics. Most machine learning algorithms expect fixed size numeric feature vectors, so text should be converted to this format. Scikit-learn provides CountVectorizer, TfidfVectorizer and HashingVectorizer for text feature extraction (see the scikit-learn docs for more info).

CountVectorizer.transform converts a collection of text documents into a matrix of token counts. The counts matrix has a column for each known token and a row for each document; the value is the number of occurrences of a token in a document.

To create the counts matrix, CountVectorizer must know which column corresponds to which token. The CountVectorizer.fit method basically remembers all tokens from some collection of documents and stores them in a “vocabulary”. The vocabulary is a Python dictionary: keys are tokens (or n-grams) and values are integer ids (column indices) ranging from 0 to len(vocabulary)-1.
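To make the fit/transform split concrete, here is a minimal pure-Python sketch of what CountVectorizer does conceptually. It is not scikit-learn’s code: the helper names (`tokenize`, `fit`, `transform`) are hypothetical, it returns dense lists instead of a scipy.sparse matrix, and it skips options such as n-grams, min_df and max_df. It does reuse CountVectorizer’s default `token_pattern` and lowercasing, and it assigns column ids in sorted token order.

```python
import re
from collections import Counter

# scikit-learn's default token_pattern: runs of 2+ word characters
TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def tokenize(doc):
    """Lowercase the document and extract tokens."""
    return TOKEN_PATTERN.findall(doc.lower())


def fit(docs):
    """Collect every token seen in docs and give it a column id (sorted order)."""
    tokens = sorted({tok for doc in docs for tok in tokenize(doc)})
    return {tok: i for i, tok in enumerate(tokens)}


def transform(docs, vocabulary):
    """Return one dense row of token counts per document; unknown tokens are ignored."""
    rows = []
    for doc in docs:
        row = [0] * len(vocabulary)
        for tok, n in Counter(tokenize(doc)).items():
            col = vocabulary.get(tok)
            if col is not None:
                row[col] = n
        rows.append(row)
    return rows
```

For example, `transform(['the dog barks at the dog'], fit(['The dog barks', 'the cat']))` gives `[[1, 0, 2, 2]]`: the columns are barks, cat, dog, the, and the unseen token “at” is dropped. The vocabulary dict returned by `fit` is exactly the structure whose memory footprint the rest of the article is about.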

Storing such a vocabulary in a standard Python dict is problematic; it can take a lot of memory even on relatively small data.

Let’s try it! Let’s use the “20 newsgroups” dataset available in scikit-learn. The “train” subset of this dataset has about 11k short documents (average document size is about 2KB, or 300 tokens; there are 130k unique tokens; average token length is 6.5).

Create and persist CountVectorizer:

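from sklearn import datasets
from sklearn.feature_extraction.text import CountVectorizer
import joblib  # older scikit-learn versions shipped this as sklearn.externals.joblib

newsgroups_train = datasets.fetch_20newsgroups(subset='train')
vec = CountVectorizer()
vec.fit(newsgroups_train.data)  # learn the vocabulary from the raw texts
joblib.dump(vec, 'vec_count.joblib')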

Load and use it:

import joblib  # formerly sklearn.externals.joblib
vec = joblib.load('vec_count.joblib')
X = vec.transform(['the dog barks'])

On my machine, the loaded vectorizer uses about 82MB of memory in this case. If we add bigrams (by using CountVectorizer(ngram_range=(1,2))) then it would take about 650MB – and this is for a corpus that is quite small.

There are only 130k unique tokens; it’ll require less than 1MB to store these tokens in a plain text file ((6.5+1) * 130k ≈ 975KB). Maybe add another megabyte to store column indices if they are not implicit (130k * 8 bytes). So the data itself should take only a couple of MBs. We may also have to somehow enumerate tokens and enable fast O(1) access to data, so there would be some overhead, but it shouldn’t take 80+MB; we’d expect 5-10MB at most. The serialized version of our CountVectorizer takes about 6MB on disk without any compression, but it expands to 80+MB when loaded into memory.

Why does it happen? There are two main reasons:

  1. Python objects are created for numbers (column indices) and strings (tokens). Each Python object has a pointer to its type + a reference counter (=> +16 bytes of overhead per object on 64-bit systems); for strings there are extra fields: length, hash, pointer to the string data, flags, etc. (the string representation differs between Python < 3.3 and Python 3.3+).
  2. Python dict is a hash table and introduces overheads: you have to store the hash table itself, pointers to keys and values, etc. There is a great talk on the Python dict implementation by Brandon Rhodes; check it out if you’re interested in knowing more.

Storing a static string->id mapping in a hash table is not the most efficient way to do it: there are perfect hashes, tries, etc. Add the Python object overhead and here we are.
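The overhead is easy to observe directly. The sketch below is a rough estimate, assuming CPython 3 on a 64-bit machine (the helper names are hypothetical): it compares the bytes a str->int dict holds, counting the hash table plus every key and value object via `sys.getsizeof`, with the raw size of the same data as estimated above. It uses a synthetic 130k-token vocabulary rather than the newsgroups data, so the absolute numbers will not match the measurements in this article; the ratio is the point.

```python
import sys


def dict_vocabulary_bytes(vocabulary):
    """Rough size of a str->int dict: the hash table plus every key and value object.

    Slightly overcounts, because CPython shares small int objects (-5..256).
    """
    total = sys.getsizeof(vocabulary)
    for token, column in vocabulary.items():
        total += sys.getsizeof(token) + sys.getsizeof(column)
    return total


def raw_vocabulary_bytes(vocabulary):
    """Tokens as newline-terminated UTF-8 text plus one 8-byte column index per token."""
    return sum(len(t.encode("utf-8")) + 1 for t in vocabulary) + 8 * len(vocabulary)


# Synthetic stand-in for a 130k-token vocabulary (not the real newsgroups tokens).
vocabulary = {f"tok{i:06d}": i for i in range(130_000)}
raw = raw_vocabulary_bytes(vocabulary)
in_dict = dict_vocabulary_bytes(vocabulary)
print(f"raw: {raw / 2**20:.1f} MB, dict: {in_dict / 2**20:.1f} MB, ratio: {in_dict / raw:.1f}x")
```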

So I decided to try an alternative storage for vocabulary. MARISA-Trie (via Python wrapper) looked like a suitable data structure, as it:

  • is a heavily optimized succinct trie-like data structure, so it compresses string data well
  • provides a unique id for each key for free, and this id is in range from 0 to len(vocabulary)-1 – we don’t have to store these indices ourselves
  • only creates Python objects (strings, integers) on demand.

MARISA-Trie is not a general replacement for dict: you can’t add a key after building, it requires more time and memory to build, lookups (via Python wrapper) are slower – about 10x slower than dict’s, and it works best for “meaningful” string keys which have common parts (not for some random data).
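marisa-trie itself is a C++ library, so the sketch below does not reproduce it. Instead it is a deliberately simpler stand-in (the `FrozenVocabulary` class is hypothetical; its `restore_key` name loosely mirrors the marisa-trie Python wrapper) that shares the properties listed above: all keys live in one bytes buffer plus an offsets array, ids 0..len-1 come for free from the sorted order, keys cannot be added after construction, and str/int objects are created only on lookup. Unlike MARISA-Trie it does not compress shared prefixes, and lookups are O(log n) binary searches rather than trie walks.

```python
import pickle
from array import array


class FrozenVocabulary:
    """Read-only token -> id mapping stored in one bytes buffer plus an offsets array."""

    def __init__(self, tokens):
        # Sorting UTF-8 bytes gives the same order as sorting by code point.
        keys = sorted({t.encode("utf-8") for t in tokens})
        self._blob = b"".join(keys)
        self._offsets = array("Q", [0])  # key i is blob[offsets[i]:offsets[i + 1]]
        for key in keys:
            self._offsets.append(self._offsets[-1] + len(key))

    def __len__(self):
        return len(self._offsets) - 1

    def _key(self, i):
        return self._blob[self._offsets[i]:self._offsets[i + 1]]

    def restore_key(self, i):
        """Return the token whose id is i (a str is created only now)."""
        return self._key(i).decode("utf-8")

    def get(self, token, default=None):
        """Binary search for token; its position in sorted order is its id."""
        target = token.encode("utf-8")
        lo, hi = 0, len(self)
        while lo < hi:
            mid = (lo + hi) // 2
            if self._key(mid) < target:
                lo = mid + 1
            else:
                hi = mid
        if lo < len(self) and self._key(lo) == target:
            return lo
        return default

    def __getitem__(self, token):
        idx = self.get(token)
        if idx is None:
            raise KeyError(token)
        return idx

    def __contains__(self, token):
        return self.get(token) is not None


vocab = FrozenVocabulary(["the", "dog", "barks", "dog"])
print(len(vocab), vocab["dog"], vocab.restore_key(2))
restored = pickle.loads(pickle.dumps(vocab))  # pickled as one bytes blob + one array
```

Because the whole vocabulary is just two container objects, pickling and unpickling it does not create one Python object per token, which is the same kind of saving the article gets from marisa-trie at load time.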

I must admit I don’t fully understand how MARISA-Tries work 🙂 The implementation is available in a folder named “grimoire”, and the only information about the implementation I could find is a set of Japanese slides, which are outdated (as the library author Susumu Yata says). It seems to be a succinct implementation of a Patricia trie which can store references to other MARISA-Tries in addition to text data; this allows it to compress more than just prefixes (as in “standard” tries). “Succinct” means the trie is encoded as a bit array.

You may never have heard of this library, but if you have a recent Android phone, MARISA-Trie is likely in your pocket: a copy of marisa-trie is in the Android 4.3+ source tree.

Ok, great, but we have to tell scikit-learn to use this data structure instead of a dict for vocabulary storage.

Scikit-learn allows passing a custom vocabulary (a dict-like object) to CountVectorizer. But this won’t help us because MARISA-Trie is not exactly dict-like; it can’t be built and modified like dict. CountVectorizer should build a vocabulary for us (using its tokenization and preprocessing features) and only then we may “freeze” it to a compact representation.

At first, we were doing it using a hack. The fit and fit_transform methods were overridden: first, they call the parent method to build a vocabulary, then they freeze that vocabulary (i.e. build a MARISA-Trie from it) and trick CountVectorizer into thinking a fixed vocabulary was passed to the constructor, and then the parent method is called once more. Calling fit/fit_transform twice is necessary because the indices learned on the first call and the indices in the frozen vocabulary are different. This quick & dirty implementation is here, and this is what we’re using in production.
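The second fit call exists only because the column ids learned during fitting and the ids the frozen structure assigns do not agree. One alternative is to translate the already-computed counts from the old ids to the new ones. This is a sketch of that idea, not necessarily how the improved version below works; `remap_columns` and the toy data are hypothetical, and the “frozen” mapping is a plain dict standing in for a trie that picks its own ids. For a scipy.sparse counts matrix the same idea amounts to a column permutation.

```python
def remap_columns(vocabulary, rows, frozen):
    """Rewrite count rows that use fit-time column ids so they use the frozen ids.

    vocabulary: token -> column id assigned while fitting
    rows: list of {column id: count} dicts, one per document
    frozen: token -> id mapping built afterwards (e.g. a trie) with its own ids
    """
    old_to_new = {old: frozen[token] for token, old in vocabulary.items()}
    return [{old_to_new[col]: count for col, count in row.items()} for row in rows]


# Fit-time ids are alphabetical; the hypothetical frozen ids are deliberately different.
vocabulary = {"barks": 0, "dog": 1, "the": 2}
rows = [{0: 1, 1: 1, 2: 1}, {1: 2}]
frozen = {"the": 0, "dog": 1, "barks": 2}
print(remap_columns(vocabulary, rows, frozen))
```

The remapping touches every stored count once, which is usually cheaper than tokenizing the whole corpus a second time.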

I recently improved it and removed this “call fit/fit_transform twice” hack for CountVectorizer, but we haven’t used this implementation yet. See https://gist.github.com/kmike/9750796.

The results? For the same dataset, MarisaCountVectorizer uses about 0.9MB for unigrams (instead of 82MB) and about 13.3MB for unigrams+bigrams (instead of 650MB+). This is a 50-90x reduction of memory usage. Tada!

The downside is that the MarisaCountVectorizer.fit and MarisaCountVectorizer.fit_transform methods are 10-30% slower than CountVectorizer’s (new version; the old version was up to about 2x slower).

Timings:

  • CountVectorizer(): 3.6s fit, 5.3s dump, 1.9s transform
  • MarisaCountVectorizer(), new version: 3.9s fit, 0.0s dump, 2.5s transform
  • MarisaCountVectorizer(), old version: 7.5s fit, 0.0s dump, 2.6s transform
  • CountVectorizer(ngram_range=(1,2)): 15.2s fit, 52.0s dump, 5.3s transform
  • MarisaCountVectorizer(ngram_range=(1,2)), new version: 18.7s fit, 0.0s dump, 6.8s transform
  • MarisaCountVectorizer(ngram_range=(1,2)), old version: 28.3s fit, 0.0s dump, 6.8s transform

The fit method was run on the ‘train’ subset of the ‘20 newsgroups’ dataset; the transform method was run on the ‘test’ subset.

marisa-trie stores all data in a contiguous memory block, so saving it to disk and loading it from disk is much faster than saving/loading a Python dict serialized using pickle.
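The benefit of a single contiguous block is that loading can be as cheap as mapping a file into memory. marisa-trie has its own save/load (and mmap) methods for this; the stdlib sketch below only illustrates the idea with a hypothetical file layout (a token count, a table of uint32 offsets, then the concatenated UTF-8 bytes) and supports only id-to-token lookups. Nothing is unpickled, and a str object is created only when a token is requested.

```python
import mmap
import os
import struct
import tempfile


def save_tokens(path, tokens):
    """Write tokens in id order: count, count + 1 uint32 offsets, then the UTF-8 bytes."""
    encoded = [t.encode("utf-8") for t in tokens]
    offsets = [0]
    for e in encoded:
        offsets.append(offsets[-1] + len(e))
    with open(path, "wb") as f:
        f.write(struct.pack("<I", len(encoded)))
        f.write(struct.pack(f"<{len(offsets)}I", *offsets))
        f.write(b"".join(encoded))


class MappedTokens:
    """id -> token lookups read straight from a memory-mapped file."""

    def __init__(self, path):
        with open(path, "rb") as f:
            # The mapping stays valid after the file object is closed.
            self._mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
        (self._n,) = struct.unpack_from("<I", self._mm, 0)
        self._data_start = 4 + 4 * (self._n + 1)  # first byte after the offsets table

    def __len__(self):
        return self._n

    def restore_key(self, i):
        if not 0 <= i < self._n:
            raise IndexError(i)
        start, end = struct.unpack_from("<II", self._mm, 4 + 4 * i)
        return self._mm[self._data_start + start:self._data_start + end].decode("utf-8")

    def close(self):
        self._mm.close()


# Usage: write three tokens, map the file, read one back.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "vocab.bin")
    save_tokens(path, ["barks", "dog", "thé"])
    tokens = MappedTokens(path)
    print(len(tokens), tokens.restore_key(2))
    tokens.close()
```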

Serialized file sizes (uncompressed):

  • CountVectorizer(): 5.9MB
  • MarisaCountVectorizer(): 371KB
  • CountVectorizer(ngram_range=(1,2)): 59MB
  • MarisaCountVectorizer(ngram_range=(1,2)): 3.8MB

TfidfVectorizer is implemented on top of CountVectorizer; it could also benefit from more efficient storage for vocabulary. I tried it, and for MarisaTfidfVectorizer the results are similar. It is possible to optimize DictVectorizer as well.

Note that MARISA-based vectorizers don’t help with memory usage during training. They may help with memory usage when saving models to disk though – pickle allocates big chunks of memory when saving Python dicts.

So when memory usage is an issue, ditch scikit-learn standard vectorizers and use marisa-based variants? Not so fast: don’t forget about HashingVectorizer. It has a number of benefits. Check the docs: HashingVectorizer doesn’t need a vocabulary so it fits and serializes in no time and it is very memory efficient because it is stateless.

As always, there are some tradeoffs:

  • HashingVectorizer.transform is irreversible (you can’t check which tokens are active), so it is harder to inspect what a classifier has learned from text data.
  • There can be hash collisions, and with a poorly chosen n_features they can hurt the prediction quality of a classifier.
  • A related disadvantage is that, unless we accept collisions, the resulting feature vectors are larger than those produced by other vectorizers. The HashingVectorizer.transform result is not useful by itself; it is usually passed to the next step (a classifier or something like PCA), and a larger input dimension can make that step use more memory and be slower to save/load, so the memory savings of HashingVectorizer could be offset by increased memory usage in subsequent steps.
  • HashingVectorizer can’t limit features based on document frequency (min_df and max_df options are not supported).
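The hashing trick behind HashingVectorizer fits in a few lines. The sketch below is an illustration, not scikit-learn’s implementation: it uses `zlib.crc32` as a stand-in hash (scikit-learn uses MurmurHash3 and, by default, alternates feature signs to reduce the effect of collisions), splits on whitespace instead of using `token_pattern`, and the function name is hypothetical. Nothing is learned or stored, so the same parameters always give the same columns, and a small `n_features` makes unrelated tokens share a column.

```python
import zlib


def hashing_transform(docs, n_features=2**20):
    """Map each token to column crc32(token) % n_features and count occurrences.

    Stateless: there is no vocabulary, so a column cannot be turned back into
    its token, and different tokens may land in the same column.
    """
    rows = []
    for doc in docs:
        row = {}
        for token in doc.lower().split():
            column = zlib.crc32(token.encode("utf-8")) % n_features
            row[column] = row.get(column, 0) + 1
        rows.append(row)
    return rows


print(hashing_transform(["the dog barks"], n_features=1))  # [{0: 3}]: every token collides
```

This is also why the n_features (and other) parameters used at prediction time must match the ones used in training: they fully determine the columns.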

Of course, all vectorizers have their own advantages and disadvantages, and there are use cases for all of them. You can use e.g. CountVectorizer for development and switch to HashingVectorizer for production, avoiding some of HashingVectorizer’s downsides. Also, don’t forget about feature selection and other similar techniques. Using succinct trie-based vectorizers is not the only way to reduce memory usage, and often it is not the best way, but sometimes they are useful; being a drop-in replacement for CountVectorizer and TfidfVectorizer helps.

In our recent project, min_df > 1 was crucial for removing noisy features. Vocabulary wasn’t the only thing that used memory; MarisaTfidfVectorizer instead of TfidfVectorizer (+ MarisaCountVectorizer instead of CountVectorizer) decreased the total classifier memory consumption by about 30%. It is not a brilliant 50x-80x, but it made the difference between “classifier fits into memory” and “classifier doesn’t fit into memory”.


There is a ticket to discuss efficient vocabulary storage with scikit-learn developers. Once the discussion settles, our plan is to make a PR to scikit-learn to make using such vectorizers easier and/or release an open-source package with MarisaCountVectorizer & friends. Stay tuned!


21 thoughts on “Optimizing Memory Usage of Scikit-Learn Models Using Succinct Tries”

  1. You should also be able to play cool data structure tricks to keep the count of features that occurred more than once, to get the list of features with min_df > 1 efficiently. For example, if you want to just get df > 1 (rather than a generic count), you can do it probabilistically with a Bloom filter, where you first check if an item exists, then insert it into the Bloom filter. You only let it pass into the vocab if it already existed before you inserted it.
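A minimal sketch of that idea, assuming the goal is the set of tokens with document frequency > 1 (the class and function names are hypothetical, and blake2b-based double hashing is just one way to derive the bit positions). A Bloom filter never forgets an inserted token, so no token that really appears in two documents is lost; false positives can let a few single-document tokens through, at a rate governed by `n_bits` and `n_hashes`.

```python
import hashlib


class BloomFilter:
    def __init__(self, n_bits=1 << 20, n_hashes=4):
        self.n_bits = n_bits
        self.n_hashes = n_hashes
        self.bits = bytearray((n_bits + 7) // 8)

    def _positions(self, item):
        # Double hashing: derive n_hashes bit positions from one 128-bit digest.
        digest = hashlib.blake2b(item.encode("utf-8"), digest_size=16).digest()
        h1 = int.from_bytes(digest[:8], "little")
        h2 = int.from_bytes(digest[8:], "little") | 1
        return [(h1 + i * h2) % self.n_bits for i in range(self.n_hashes)]

    def add_and_check(self, item):
        """Insert item; return True if it was (probably) present before."""
        seen = True
        for pos in self._positions(item):
            byte, bit = divmod(pos, 8)
            if not self.bits[byte] & (1 << bit):
                seen = False
                self.bits[byte] |= 1 << bit
        return seen


def tokens_seen_twice(docs):
    """Tokens that occur in at least two documents (plus rare false positives)."""
    bloom = BloomFilter()
    vocab = set()
    for doc in docs:
        for token in set(doc.lower().split()):  # each token counts once per document
            if bloom.add_and_check(token):
                vocab.add(token)
    return vocab
```

Only the tokens that pass the filter are stored exactly, so the exact vocabulary never holds the long tail of df = 1 tokens.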

  2. Nice post, thanks for the idea! I’ve tried to implement it a little differently, avoiding the expensive call to the CountVectorizer.fit_transform() method. The code is here: https://gist.github.com/rokroskar/89d85334f565a25a9960#file-marisa_vectorizer-py

    I’m not able to reproduce your numbers for your implementation of MarisaCountVectorizer — it always runs considerably slower and uses more memory than the sklearn CountVectorizer. It’s not clear to me how it could actually save on memory use if it calls the same fit_transform() method making the same expensive dict?

    1. Hi,

      Memory savings mentioned in the article are for prediction time, not for training time: you fit CountVectorizer, save it to disk (maybe with a model that uses it), then load (e.g. on the production server) – and it uses much less memory than before. As you said, fitting is slower and uses more memory; the upside is that serialization is much faster and uses less memory, deserialization is faster and transform() requires much less memory.

      Nice idea about using marisa-trie during fit_transform() directly. It won’t always work though: to support min_df / max_df / max_features parameters you need to delete keys from the mapping which is not possible with marisa-trie. But more memory efficient fit() may be helpful sometimes.

      1. aha I see — as long as the initial fit fits into memory you’re fine then.

        yes this is not a final/best solution — I was toying with the idea of using your datrie implementation instead so that removal would be possible, but inserting elements seems to take a really long time. Perhaps building a new vocabulary followed by a remapping of the transformation vector would be the way to go if sticking with the marisa-trie.

      2. Yes, datrie is going to be very slow at build time because it needs to move large chunks of memory.

        Unfortunately I haven’t got to submit a PR to scikit-learn yet, but in https://github.com/scikit-learn/scikit-learn/issues/2639 there are results for hat-trie, std::map and std::unordered_map (code: https://gist.github.com/kmike/9819115); they all can make vocabulary building more memory efficient at the cost of speed penalty. Memory savings are substantial, but nothing near marisa-trie savings at prediction time.

  3. Nice post!

    On a somewhat related note: regarding the use of memory while training on the output of CountVectorizer or HashingVectorizer, several sklearn classifiers and regressors provide a method called partial_fit(), which allows training on chunks one by one instead of loading all training rows into memory at the same time, which is very useful when trying to train on a large data set.

  4. The function “_freeze_vocabulary” would raise “AttributeError: can’t set attribute”.


    The attribute “fixed_vocabulary” is deprecated in recent versions of CountVectorizer (refer to sklearn docs) and replaced by “fixed_vocabulary_”

    In both implementations of the MarisaCountVectorizer, in the function “_freeze_vocabulary”, the following line

    self.fixed_vocabulary = True

    should be replaced by this line

    self.fixed_vocabulary_ = True

    Please, let me know if my observation is wrong.

  5. An alternate approach is to:
    1. Use Count/TfIdf Vectorizers as usual in a pipeline with or without grid search.
    2. Once the best estimator is known, then get the feature names from the vectorizer http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html#sklearn.feature_extraction.text.CountVectorizer.get_feature_names
    3. Train a new classifier with the vectorizer initialized with the above feature names as an explicit vocabulary

    Was able to reduce 150MB to under 1MB

  6. I agree with Santhosh. Apart from the vocabulary, the CountVectorizer also saves a mapping of all the training examples, while you only need the vocabulary of the CountVectorizer to classify new instances.

    For this you can create a new CountVectorizer instance with only the vocabulary of the old one:

    new_count_vect = CountVectorizer(vocabulary=old_count_vect.vocabulary_)

    Then you can save and load this new CountVectorizer to transform new instances to the desired state:


    The new CountVectorizer is only a small fraction (1% in my case) of the old CountVectorizer. For more info see: http://stackoverflow.com/a/22920964/1970722

  7. @Mikhail I have tested the ‘MARISA-Trie based vectorizers’ using the “transforming” script. I am not getting the desired result. I am still not able to convince myself that this modified vectorizer takes less memory than TFIDF. My OS is Windows 8. I am using Python 2.7.8. Please let me know where I am going wrong.

    1. Hey @Nitin – the scripts are two years old, and relevant scikit-learn parts were refactored, so I’m not 100% sure they work now as-is. But it is hard to tell what’s the problem in your case without more details about what you’re doing exactly.

      1. Thanks for the reply. After saving your “MARISA-Trie based vectorizers” Python script, I decided to use the Marisa Trie based Vectorizer. Following is the code that I wrote to compare the Marisa vs. TFIDF Vectorizer.

        Code for creating the vectorizers

        newsgroups_train = datasets.fetch_20newsgroups(subset='train')
        vec = TfidfVectorizer(ngram_range=(1,2))
        vec2 = MarisaCountVectorizer(ngram_range=(1,2))
        joblib.dump(vec, 'vec_tfidf.joblib.p')
        joblib.dump(vec2, 'vec_Marisa.joblib.p')

        Code for checking the memory of these Vectorizers

        newsgroups_test = datasets.fetch_20newsgroups(subset='test')
        p = psutil.Process(os.getpid())
        before = p.get_memory_info().rss / 2**20
        print before
        vec = joblib.load('vec_Marisa.joblib.p')
        after_load = p.get_memory_info().rss / 2**20
        print after_load
        X = vec.transform(newsgroups_test)
        after_load = p.get_memory_info().rss / 2**20
        print after_load

        Please let me know if I am doing anything wrong

        Many Thanks in Advance 🙂

  8. In the above comment I have written the memory check code for the Marisa Vectorizer. Similarly, I have written it for the TFIDF Vectorizer, but the TFIDF memory results are not very different from the Marisa Vectorizer’s. Please let me know what you think.


    1. Yeah, you’re right – the problem is that code needs to be modified to account for scikit-learn changes. If you try it with scikit-learn 0.14.1 it should give proper results, but with scikit-learn 0.15+ it does nothing useful.

  9. Thank you very much for the post!!! It helped me a lot.

    Because the HashingVectorizer is stateless, there is no need to save/dump it to disk, right? Since there is no fitting, only transforms.

    Thanks again. Cheers!!!

    1. Hey,

      You’re right that there is no fitting in HashingVectorizer, but usually it is still pickled and saved to disk, for two reasons:

      1. It is convenient to do training once, then dump the whole scikit-learn processing pipeline (e.g. implemented using http://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html) to disk, and then load it (maybe on another machine) to use for predictions. HashingVectorizer is usually a part of pipeline, so it is also dumped.

      2. At prediction time you need to use HashingVectorizer with the same parameters (n_features, ngram_range, etc) as in training to get correct results; saving HashingVectorizer instance (including saving it as a part of pipeline) ensures you won’t create the vectorizer with incompatible parameters.

    2. My text classification pipeline with HashingVectorizer performed worse than with CountVectorizer(binary=True), and I don’t know why. Does anybody have the same experience?

      1. Hey,

        It is hard to tell, but maybe n_features is too low (and so there are too many collisions) , or maybe other parameters don’t match (e.g. are you using binary=False for HashingVectorizer as well?).
