Named Entity Recognition with Conditional Random Fields

One of the classic challenges of Natural Language Processing is sequence labelling. In sequence labelling, the goal is to label each word in a text with a word class. In part-of-speech tagging, these word classes are parts of speech, such as noun or verb. In named entity recognition (NER), they're types of named entities: generic ones, such as locations, people or organizations, or more specialized ones, such as diseases or symptoms in the healthcare domain. In this way, sequence labelling can help us extract the most important information from a text and improve the performance of analytics, search or matching applications.

In this notebook we'll explore Conditional Random Fields, the most popular approach to sequence labelling before Deep Learning arrived. Deep Learning may get all the attention right now, but Conditional Random Fields are still a powerful tool to build a simple sequence labeller.

The tool we're going to use is sklearn-crfsuite. This is a wrapper around python-crfsuite, which itself is a Python binding of CRFsuite. The reason we're using sklearn-crfsuite is that it provides a number of handy utility functions, for example for evaluating the output of the model. You can install it with pip install sklearn-crfsuite.

Data

First we get some data. A well-known data set for training and testing NER models is the CoNLL-2002 data, which has Spanish and Dutch texts labelled with four types of entities: locations (LOC), persons (PER), organizations (ORG) and miscellaneous entities (MISC). Both corpora are split up into three portions: a training portion and two smaller test portions, one of which we'll use as development data. It's easy to collect the data from NLTK.

import nltk
import sklearn
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelBinarizer
import sklearn_crfsuite as crfsuite
from sklearn_crfsuite import metrics

nltk.download('conll2002')  # fetch the CoNLL-2002 corpus if it is not present yet

train_sents = list(nltk.corpus.conll2002.iob_sents('ned.train'))
dev_sents = list(nltk.corpus.conll2002.iob_sents('ned.testa'))
test_sents = list(nltk.corpus.conll2002.iob_sents('ned.testb'))

The data consists of a list of tokenized sentences. For each of the tokens we have the string itself, its part-of-speech tag and its entity tag, which follows the BIO convention. In the deep learning world we live in today, it's common to ignore the part-of-speech tags. However, since CRFs rely on good feature extraction, we'll gladly make use of this information. After all, the part of speech of a word tells us a lot about its possible status as a named entity: nouns will more often be entities than verbs, for example.
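
The BIO convention works as follows: the first token of an entity receives a B- (begin) tag, any further tokens of the same entity receive an I- (inside) tag, and tokens outside any entity are tagged O. As a purely illustrative example (not taken from the corpus), a two-token location such as New York would be labelled like this:

[('New', 'N', 'B-LOC'), ('York', 'N', 'I-LOC')]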

train_sents[0]
[('De', 'Art', 'O'),
 ('tekst', 'N', 'O'),
 ('van', 'Prep', 'O'),
 ('het', 'Art', 'O'),
 ('arrest', 'N', 'O'),
 ('is', 'V', 'O'),
 ('nog', 'Adv', 'O'),
 ('niet', 'Adv', 'O'),
 ('schriftelijk', 'Adj', 'O'),
 ('beschikbaar', 'Adj', 'O'),
 ('maar', 'Conj', 'O'),
 ('het', 'Art', 'O'),
 ('bericht', 'N', 'O'),
 ('werd', 'V', 'O'),
 ('alvast', 'Adv', 'O'),
 ('bekendgemaakt', 'V', 'O'),
 ('door', 'Prep', 'O'),
 ('een', 'Art', 'O'),
 ('communicatiebureau', 'N', 'O'),
 ('dat', 'Conj', 'O'),
 ('Floralux', 'N', 'B-ORG'),
 ('inhuurde', 'V', 'O'),
 ('.', 'Punc', 'O')]

Feature Extraction

Whereas today neural networks are expected to learn the relevant features of the input texts themselves, this is very different with Conditional Random Fields. CRFs learn the relationship between the features we give them and the label of a token in a given context. They're not going to learn these features themselves. Instead, the quality of the model will depend heavily on the relevance of the features we show it.

The most important method in this tutorial is therefore the one that collects the features for every token. What information could be useful? The word itself, of course, together with its part of speech tag. It can also be interesting to know whether the word is completely uppercase, whether it starts with a capital or is a digit. In addition, we also take a look at the character bigram and trigram the word ends with. We also give every token a bias feature, which always has the same value. This bias feature helps the CRF learn the relative frequency of each label type in the training data.

To give the CRF more information about the meaning of a word, we also introduce information from word embeddings. In our Word Embedding notebook, we trained word embeddings on Dutch Wikipedia and clustered them into 500 clusters. Here we'll read these 500 clusters from a file, and map each word to the id of the cluster it is in. This is really useful for Named Entity Recognition, as most entity types cluster together. This allows CRFs to generalize above the word level. For example, when the CRF encounters a word it has never seen (say, Albania), it can base its decision on the cluster the word is in. If this cluster contains many other entities the CRF has met in its training data (say, Italy, Germany and France), it will have learnt a strong link between this cluster and a specific entity type. As a result, it can still assign that entity type to the unknown word. In our experiments, this feature alone boosts the performance by around 3%.
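
If you want to build similar clusters yourself, the sketch below shows one possible recipe, assuming you have a tokenized, lowercased Dutch corpus and the gensim (4.x) and scikit-learn libraries. The parameter values and the output path are illustrative choices, not the exact settings used for the clusters in this notebook.

# Illustrative sketch: train word embeddings and cluster them into 500 groups.
# Assumes `corpus` is an iterable of tokenized, lowercased sentences.
from gensim.models import Word2Vec
from sklearn.cluster import KMeans

model = Word2Vec(sentences=corpus, vector_size=100, window=5, min_count=5)
cluster_ids = KMeans(n_clusters=500, random_state=1).fit_predict(model.wv.vectors)

# Write one "word<TAB>cluster_id" line per word: the format read_clusters() below expects.
with open("data/embeddings/clusters_nl.tsv", "w") as o:
    for word, cluster_id in zip(model.wv.index_to_key, cluster_ids):
        o.write(f"{word}\t{cluster_id}\n")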

Finally, apart from the token itself, we also want the CRF to look at its context. More specifically, we're going to give it some extra information about the two words to the left and the right of the target word. We'll tell the CRF what these words are, whether they start with a capital or are completely uppercase, and give it their part-of-speech tag. If there is no left or right context, we'll inform the CRF that the token is at the beginning or end of the sentence (BOS or EOS).

def read_clusters(cluster_file):
    word2cluster = {}
    with open(cluster_file) as i:
        for line in i:
            word, cluster = line.strip().split('\t')
            word2cluster[word] = cluster
    return word2cluster


def word2features(sent, i, word2cluster):
    word = sent[i][0]
    postag = sent[i][1]
    features = [
        'bias',
        'word.lower=' + word.lower(),
        'word[-3:]=' + word[-3:],
        'word[-2:]=' + word[-2:],
        'word.isupper=%s' % word.isupper(),
        'word.istitle=%s' % word.istitle(),
        'word.isdigit=%s' % word.isdigit(),
        'word.cluster=%s' % (word2cluster[word.lower()] if word.lower() in word2cluster else "0"),  # unknown words fall back to cluster "0"
        'postag=' + postag
    ]
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.extend([
            '-1:word.lower=' + word1.lower(),
            '-1:word.istitle=%s' % word1.istitle(),
            '-1:word.isupper=%s' % word1.isupper(),
            '-1:postag=' + postag1
        ])
    else:
        features.append('BOS')

    if i > 1: 
        word2 = sent[i-2][0]
        postag2 = sent[i-2][1]
        features.extend([
            '-2:word.lower=' + word2.lower(),
            '-2:word.istitle=%s' % word2.istitle(),
            '-2:word.isupper=%s' % word2.isupper(),
            '-2:postag=' + postag2
        ])        

        
    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.extend([
            '+1:word.lower=' + word1.lower(),
            '+1:word.istitle=%s' % word1.istitle(),
            '+1:word.isupper=%s' % word1.isupper(),
            '+1:postag=' + postag1
        ])
    else:
        features.append('EOS')

    if i < len(sent)-2:
        word2 = sent[i+2][0]
        postag2 = sent[i+2][1]
        features.extend([
            '+2:word.lower=' + word2.lower(),
            '+2:word.istitle=%s' % word2.istitle(),
            '+2:word.isupper=%s' % word2.isupper(),
            '+2:postag=' + postag2
        ])

        
    return features


def sent2features(sent, word2cluster):
    return [word2features(sent, i, word2cluster) for i in range(len(sent))]

def sent2labels(sent):
    return [label for token, postag, label in sent]

def sent2tokens(sent):
    return [token for token, postag, label in sent]

word2cluster = read_clusters("data/embeddings/clusters_nl.tsv")
sent2features(train_sents[0], word2cluster)[0]
['bias',
 'word.lower=de',
 'word[-3:]=De',
 'word[-2:]=De',
 'word.isupper=False',
 'word.istitle=True',
 'word.isdigit=False',
 'word.cluster=38',
 'postag=Art',
 'BOS',
 '+1:word.lower=tekst',
 '+1:word.istitle=False',
 '+1:word.isupper=False',
 '+1:postag=N',
 '+2:word.lower=van',
 '+2:word.istitle=False',
 '+2:word.isupper=False',
 '+2:postag=Prep']
X_train = [sent2features(s, word2cluster) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]

X_dev = [sent2features(s, word2cluster) for s in dev_sents]
y_dev = [sent2labels(s) for s in dev_sents]

X_test = [sent2features(s, word2cluster) for s in test_sents]
y_test = [sent2labels(s) for s in test_sents]

Training

We now create a CRF model and train it. We'll use the standard L-BFGS algorithm for our parameter estimation and run it for 100 iterations. When we're done, we save the model with joblib.

crf = crfsuite.CRF(
    verbose='true',
    algorithm='lbfgs',
    max_iterations=100
)

crf.fit(X_train, y_train, X_dev=X_dev, y_dev=y_dev)
loading training data to CRFsuite: 100%|██████████| 15806/15806 [00:02<00:00, 7623.17it/s]
loading dev data to CRFsuite:  27%|██▋       | 769/2895 [00:00<00:00, 7689.13it/s]
loading dev data to CRFsuite: 100%|██████████| 2895/2895 [00:00<00:00, 7186.08it/s]
Holdout group: 2

Feature generation
type: CRF1d
feature.minfreq: 0.000000
feature.possible_states: 0
feature.possible_transitions: 0
0....1....2....3....4....5....6....7....8....9....10
Number of features: 152117
Seconds required: 0.424

L-BFGS optimization
c1: 0.000000
c2: 1.000000
num_memories: 6
max_iterations: 100
epsilon: 0.000010
stop: 10
delta: 0.000010
linesearch: MoreThuente
linesearch.max_iterations: 20

Iter 1   time=0.37  loss=104214.83 active=152117 precision=0.100  recall=0.111  F1=0.105  Acc(item/seq)=0.901 0.496  feature_norm=1.00
Iter 2   time=0.21  loss=96997.81 active=152117 precision=0.100  recall=0.111  F1=0.105  Acc(item/seq)=0.901 0.496  feature_norm=1.13
Iter 3   time=0.21  loss=92085.38 active=152117 precision=0.100  recall=0.111  F1=0.105  Acc(item/seq)=0.901 0.496  feature_norm=1.26
Iter 4   time=0.21  loss=84277.67 active=152117 precision=0.100  recall=0.111  F1=0.105  Acc(item/seq)=0.901 0.496  feature_norm=1.51
Iter 5   time=0.21  loss=67577.53 active=152117 precision=0.169  recall=0.113  F1=0.109  Acc(item/seq)=0.902 0.496  feature_norm=2.32
Iter 6   time=0.21  loss=47854.26 active=152117 precision=0.326  recall=0.347  F1=0.320  Acc(item/seq)=0.930 0.580  feature_norm=4.34
Iter 7   time=0.21  loss=43326.19 active=152117 precision=0.340  recall=0.365  F1=0.333  Acc(item/seq)=0.933 0.592  feature_norm=5.06
Iter 8   time=0.21  loss=38617.07 active=152117 precision=0.372  recall=0.399  F1=0.362  Acc(item/seq)=0.938 0.618  feature_norm=6.43
Iter 9   time=0.21  loss=35511.85 active=152117 precision=0.491  recall=0.442  F1=0.421  Acc(item/seq)=0.942 0.631  feature_norm=8.62
Iter 10  time=0.21  loss=32735.31 active=152117 precision=0.535  recall=0.456  F1=0.445  Acc(item/seq)=0.944 0.654  feature_norm=9.50
Iter 11  time=0.21  loss=31687.17 active=152117 precision=0.530  recall=0.491  F1=0.502  Acc(item/seq)=0.947 0.663  feature_norm=10.94
Iter 12  time=0.21  loss=29323.19 active=152117 precision=0.532  recall=0.482  F1=0.477  Acc(item/seq)=0.946 0.666  feature_norm=11.18
Iter 13  time=0.21  loss=28733.55 active=152117 precision=0.596  recall=0.489  F1=0.489  Acc(item/seq)=0.947 0.668  feature_norm=11.58
Iter 14  time=0.21  loss=27120.69 active=152117 precision=0.640  recall=0.507  F1=0.512  Acc(item/seq)=0.948 0.677  feature_norm=12.36
Iter 15  time=0.21  loss=24849.05 active=152117 precision=0.640  recall=0.547  F1=0.558  Acc(item/seq)=0.952 0.697  feature_norm=13.86
Iter 16  time=0.40  loss=24033.40 active=152117 precision=0.654  recall=0.580  F1=0.586  Acc(item/seq)=0.954 0.706  feature_norm=14.53
Iter 17  time=0.21  loss=22935.94 active=152117 precision=0.669  recall=0.578  F1=0.598  Acc(item/seq)=0.955 0.712  feature_norm=15.15
Iter 18  time=0.21  loss=21803.53 active=152117 precision=0.682  recall=0.584  F1=0.605  Acc(item/seq)=0.956 0.713  feature_norm=15.67
Iter 19  time=0.21  loss=21046.75 active=152117 precision=0.725  recall=0.565  F1=0.586  Acc(item/seq)=0.956 0.717  feature_norm=16.04
Iter 20  time=0.21  loss=20465.96 active=152117 precision=0.701  recall=0.566  F1=0.594  Acc(item/seq)=0.956 0.711  feature_norm=15.81
Iter 21  time=0.21  loss=19991.29 active=152117 precision=0.706  recall=0.586  F1=0.611  Acc(item/seq)=0.957 0.717  feature_norm=15.63
Iter 22  time=0.21  loss=19560.23 active=152117 precision=0.684  recall=0.597  F1=0.616  Acc(item/seq)=0.957 0.722  feature_norm=15.58
Iter 23  time=0.21  loss=19241.14 active=152117 precision=0.672  recall=0.602  F1=0.615  Acc(item/seq)=0.957 0.726  feature_norm=15.65
Iter 24  time=0.21  loss=18787.87 active=152117 precision=0.678  recall=0.627  F1=0.637  Acc(item/seq)=0.958 0.731  feature_norm=16.31
Iter 25  time=0.21  loss=18145.13 active=152117 precision=0.690  recall=0.625  F1=0.640  Acc(item/seq)=0.959 0.734  feature_norm=16.94
Iter 26  time=0.21  loss=17786.38 active=152117 precision=0.710  recall=0.621  F1=0.642  Acc(item/seq)=0.959 0.738  feature_norm=17.48
Iter 27  time=0.21  loss=17247.02 active=152117 precision=0.711  recall=0.625  F1=0.649  Acc(item/seq)=0.959 0.736  feature_norm=18.46
Iter 28  time=0.21  loss=16876.01 active=152117 precision=0.737  recall=0.627  F1=0.655  Acc(item/seq)=0.960 0.748  feature_norm=19.83
Iter 29  time=0.21  loss=16543.87 active=152117 precision=0.732  recall=0.631  F1=0.663  Acc(item/seq)=0.961 0.751  feature_norm=20.12
Iter 30  time=0.21  loss=16263.21 active=152117 precision=0.725  recall=0.644  F1=0.671  Acc(item/seq)=0.962 0.753  feature_norm=20.42
Iter 31  time=0.21  loss=15665.78 active=152117 precision=0.715  recall=0.661  F1=0.676  Acc(item/seq)=0.963 0.758  feature_norm=21.50
Iter 32  time=0.21  loss=15247.34 active=152117 precision=0.700  recall=0.641  F1=0.650  Acc(item/seq)=0.961 0.748  feature_norm=23.18
Iter 33  time=0.21  loss=14866.51 active=152117 precision=0.702  recall=0.658  F1=0.670  Acc(item/seq)=0.963 0.754  feature_norm=24.67
Iter 34  time=0.21  loss=14650.96 active=152117 precision=0.704  recall=0.659  F1=0.675  Acc(item/seq)=0.963 0.755  feature_norm=25.38
Iter 35  time=0.21  loss=14386.98 active=152117 precision=0.730  recall=0.677  F1=0.695  Acc(item/seq)=0.964 0.761  feature_norm=26.47
Iter 36  time=0.21  loss=14158.49 active=152117 precision=0.742  recall=0.681  F1=0.704  Acc(item/seq)=0.965 0.763  feature_norm=28.58
Iter 37  time=0.21  loss=13895.30 active=152117 precision=0.736  recall=0.684  F1=0.701  Acc(item/seq)=0.965 0.765  feature_norm=28.72
Iter 38  time=0.21  loss=13656.45 active=152117 precision=0.730  recall=0.683  F1=0.695  Acc(item/seq)=0.965 0.768  feature_norm=29.07
Iter 39  time=0.21  loss=13499.28 active=152117 precision=0.727  recall=0.680  F1=0.691  Acc(item/seq)=0.965 0.769  feature_norm=29.61
Iter 40  time=0.21  loss=13174.95 active=152117 precision=0.726  recall=0.677  F1=0.689  Acc(item/seq)=0.965 0.766  feature_norm=31.03
Iter 41  time=0.21  loss=13104.00 active=152117 precision=0.736  recall=0.662  F1=0.678  Acc(item/seq)=0.964 0.760  feature_norm=33.06
Iter 42  time=0.21  loss=12750.79 active=152117 precision=0.731  recall=0.685  F1=0.701  Acc(item/seq)=0.966 0.764  feature_norm=34.35
Iter 43  time=0.21  loss=12637.02 active=152117 precision=0.740  recall=0.690  F1=0.708  Acc(item/seq)=0.966 0.766  feature_norm=34.63
Iter 44  time=0.21  loss=12534.20 active=152117 precision=0.745  recall=0.692  F1=0.712  Acc(item/seq)=0.966 0.766  feature_norm=35.30
Iter 45  time=0.21  loss=12390.89 active=152117 precision=0.739  recall=0.682  F1=0.702  Acc(item/seq)=0.966 0.761  feature_norm=37.15
Iter 46  time=0.21  loss=12277.42 active=152117 precision=0.733  recall=0.689  F1=0.704  Acc(item/seq)=0.966 0.763  feature_norm=37.58
Iter 47  time=0.21  loss=12219.03 active=152117 precision=0.739  recall=0.690  F1=0.706  Acc(item/seq)=0.966 0.767  feature_norm=37.20
Iter 48  time=0.21  loss=12125.64 active=152117 precision=0.743  recall=0.695  F1=0.711  Acc(item/seq)=0.967 0.770  feature_norm=36.99
Iter 49  time=0.21  loss=11970.65 active=152117 precision=0.745  recall=0.697  F1=0.712  Acc(item/seq)=0.967 0.775  feature_norm=36.84
Iter 50  time=0.21  loss=11780.73 active=152117 precision=0.754  recall=0.696  F1=0.718  Acc(item/seq)=0.968 0.777  feature_norm=37.57
Iter 51  time=0.21  loss=11623.83 active=152117 precision=0.740  recall=0.691  F1=0.710  Acc(item/seq)=0.967 0.774  feature_norm=38.21
Iter 52  time=0.21  loss=11549.38 active=152117 precision=0.740  recall=0.690  F1=0.709  Acc(item/seq)=0.967 0.773  feature_norm=38.94
Iter 53  time=0.21  loss=11497.85 active=152117 precision=0.739  recall=0.696  F1=0.713  Acc(item/seq)=0.967 0.772  feature_norm=39.54
Iter 54  time=0.21  loss=11419.64 active=152117 precision=0.733  recall=0.692  F1=0.707  Acc(item/seq)=0.966 0.773  feature_norm=40.40
Iter 55  time=0.21  loss=11280.29 active=152117 precision=0.744  recall=0.707  F1=0.719  Acc(item/seq)=0.967 0.773  feature_norm=41.75
Iter 56  time=0.21  loss=11131.39 active=152117 precision=0.746  recall=0.710  F1=0.722  Acc(item/seq)=0.968 0.773  feature_norm=42.72
Iter 57  time=0.21  loss=11043.40 active=152117 precision=0.751  recall=0.713  F1=0.726  Acc(item/seq)=0.968 0.774  feature_norm=42.92
Iter 58  time=0.21  loss=10954.38 active=152117 precision=0.769  recall=0.713  F1=0.736  Acc(item/seq)=0.969 0.781  feature_norm=43.18
Iter 59  time=0.21  loss=10836.31 active=152117 precision=0.773  recall=0.713  F1=0.736  Acc(item/seq)=0.969 0.781  feature_norm=43.74
Iter 60  time=0.21  loss=10712.24 active=152117 precision=0.779  recall=0.719  F1=0.744  Acc(item/seq)=0.970 0.788  feature_norm=44.41
Iter 61  time=0.21  loss=10602.81 active=152117 precision=0.789  recall=0.709  F1=0.740  Acc(item/seq)=0.970 0.789  feature_norm=44.68
Iter 62  time=0.21  loss=10508.84 active=152117 precision=0.782  recall=0.711  F1=0.739  Acc(item/seq)=0.970 0.787  feature_norm=45.37
Iter 63  time=0.21  loss=10458.88 active=152117 precision=0.783  recall=0.717  F1=0.744  Acc(item/seq)=0.970 0.788  feature_norm=45.51
Iter 64  time=0.21  loss=10420.78 active=152117 precision=0.763  recall=0.711  F1=0.730  Acc(item/seq)=0.969 0.786  feature_norm=45.67
Iter 65  time=0.21  loss=10315.28 active=152117 precision=0.766  recall=0.721  F1=0.735  Acc(item/seq)=0.969 0.786  feature_norm=46.30
Iter 66  time=0.21  loss=10204.10 active=152117 precision=0.769  recall=0.728  F1=0.740  Acc(item/seq)=0.970 0.786  feature_norm=47.10
Iter 67  time=0.21  loss=10134.54 active=152117 precision=0.769  recall=0.716  F1=0.737  Acc(item/seq)=0.970 0.787  feature_norm=47.87
Iter 68  time=0.21  loss=10095.10 active=152117 precision=0.773  recall=0.718  F1=0.741  Acc(item/seq)=0.970 0.787  feature_norm=47.85
Iter 69  time=0.21  loss=10059.52 active=152117 precision=0.773  recall=0.715  F1=0.738  Acc(item/seq)=0.970 0.790  feature_norm=47.81
Iter 70  time=0.21  loss=10012.53 active=152117 precision=0.767  recall=0.712  F1=0.733  Acc(item/seq)=0.970 0.790  feature_norm=47.95
Iter 71  time=0.21  loss=9931.41  active=152117 precision=0.765  recall=0.710  F1=0.731  Acc(item/seq)=0.970 0.791  feature_norm=48.49
Iter 72  time=0.21  loss=9861.45  active=152117 precision=0.763  recall=0.709  F1=0.730  Acc(item/seq)=0.970 0.793  feature_norm=49.07
Iter 73  time=0.21  loss=9803.75  active=152117 precision=0.769  recall=0.723  F1=0.741  Acc(item/seq)=0.970 0.796  feature_norm=49.32
Iter 74  time=0.21  loss=9762.61  active=152117 precision=0.758  recall=0.720  F1=0.733  Acc(item/seq)=0.970 0.793  feature_norm=49.58
Iter 75  time=0.21  loss=9726.47  active=152117 precision=0.761  recall=0.722  F1=0.736  Acc(item/seq)=0.970 0.790  feature_norm=49.80
Iter 76  time=0.21  loss=9628.97  active=152117 precision=0.764  recall=0.722  F1=0.736  Acc(item/seq)=0.970 0.789  feature_norm=50.47
Iter 77  time=0.40  loss=9586.61  active=152117 precision=0.763  recall=0.725  F1=0.740  Acc(item/seq)=0.971 0.791  feature_norm=50.96
Iter 78  time=0.21  loss=9522.00  active=152117 precision=0.767  recall=0.723  F1=0.741  Acc(item/seq)=0.970 0.788  feature_norm=51.34
Iter 79  time=0.21  loss=9479.87  active=152117 precision=0.765  recall=0.720  F1=0.736  Acc(item/seq)=0.970 0.788  feature_norm=51.77
Iter 80  time=0.21  loss=9448.33  active=152117 precision=0.771  recall=0.720  F1=0.739  Acc(item/seq)=0.970 0.789  feature_norm=51.80
Iter 81  time=0.21  loss=9423.55  active=152117 precision=0.769  recall=0.723  F1=0.740  Acc(item/seq)=0.970 0.789  feature_norm=51.83
Iter 82  time=0.21  loss=9367.51  active=152117 precision=0.763  recall=0.720  F1=0.736  Acc(item/seq)=0.970 0.790  feature_norm=52.25
Iter 83  time=0.21  loss=9322.98  active=152117 precision=0.762  recall=0.722  F1=0.737  Acc(item/seq)=0.970 0.793  feature_norm=52.34
Iter 84  time=0.21  loss=9277.51  active=152117 precision=0.765  recall=0.725  F1=0.740  Acc(item/seq)=0.971 0.798  feature_norm=52.47
Iter 85  time=0.21  loss=9229.07  active=152117 precision=0.772  recall=0.726  F1=0.743  Acc(item/seq)=0.971 0.798  feature_norm=52.79
Iter 86  time=0.21  loss=9214.50  active=152117 precision=0.782  recall=0.729  F1=0.751  Acc(item/seq)=0.971 0.799  feature_norm=52.88
Iter 87  time=0.21  loss=9194.55  active=152117 precision=0.785  recall=0.731  F1=0.753  Acc(item/seq)=0.971 0.800  feature_norm=52.82
Iter 88  time=0.21  loss=9182.23  active=152117 precision=0.783  recall=0.728  F1=0.750  Acc(item/seq)=0.971 0.798  feature_norm=52.86
Iter 89  time=0.21  loss=9161.54  active=152117 precision=0.786  recall=0.726  F1=0.750  Acc(item/seq)=0.971 0.795  feature_norm=52.98
Iter 90  time=0.21  loss=9130.06  active=152117 precision=0.788  recall=0.728  F1=0.752  Acc(item/seq)=0.971 0.797  feature_norm=53.17
Iter 91  time=0.21  loss=9063.09  active=152117 precision=0.794  recall=0.738  F1=0.760  Acc(item/seq)=0.972 0.800  feature_norm=53.64
Iter 92  time=0.40  loss=9041.98  active=152117 precision=0.791  recall=0.739  F1=0.759  Acc(item/seq)=0.972 0.801  feature_norm=53.77
Iter 93  time=0.21  loss=9010.27  active=152117 precision=0.783  recall=0.736  F1=0.754  Acc(item/seq)=0.972 0.801  feature_norm=53.97
Iter 94  time=0.21  loss=8984.80  active=152117 precision=0.786  recall=0.740  F1=0.759  Acc(item/seq)=0.972 0.803  feature_norm=54.09
Iter 95  time=0.21  loss=8965.14  active=152117 precision=0.788  recall=0.740  F1=0.760  Acc(item/seq)=0.972 0.803  feature_norm=54.14
Iter 96  time=0.21  loss=8931.00  active=152117 precision=0.786  recall=0.737  F1=0.758  Acc(item/seq)=0.972 0.804  feature_norm=54.26
Iter 97  time=0.21  loss=8923.90  active=152117 precision=0.791  recall=0.737  F1=0.757  Acc(item/seq)=0.971 0.799  feature_norm=54.68
Iter 98  time=0.21  loss=8860.78  active=152117 precision=0.784  recall=0.735  F1=0.755  Acc(item/seq)=0.971 0.802  feature_norm=54.64
Iter 99  time=0.21  loss=8846.13  active=152117 precision=0.788  recall=0.737  F1=0.758  Acc(item/seq)=0.972 0.803  feature_norm=54.68
Iter 100 time=0.21  loss=8832.34  active=152117 precision=0.789  recall=0.737  F1=0.758  Acc(item/seq)=0.972 0.803  feature_norm=54.77
================================================
Label      Precision    Recall     F1    Support
-------  -----------  --------  -----  ---------
B-LOC          0.823     0.823  0.823        479
B-MISC         0.806     0.651  0.720        748
B-ORG          0.853     0.620  0.718        686
B-PER          0.752     0.855  0.800        703
I-LOC          0.583     0.547  0.565         64
I-MISC         0.645     0.507  0.568        215
I-ORG          0.806     0.684  0.740        396
I-PER          0.844     0.946  0.892        423
O              0.990     0.998  0.994      33973
------------------------------------------------
L-BFGS terminated with the maximum number of iterations
Total seconds required for training: 21.768

Storing the model
Number of active features: 152117 (152117)
Number of active attributes: 130306 (145602)
Number of active labels: 9 (9)
Writing labels
Writing attributes
Writing feature references for transitions
Writing feature references for attributes
Seconds required: 0.058

CRF(algorithm='lbfgs', all_possible_states=None,
  all_possible_transitions=None, averaging=None, c=None, c1=None, c2=None,
  calibration_candidates=None, calibration_eta=None,
  calibration_max_trials=None, calibration_rate=None,
  calibration_samples=None, delta=None, epsilon=None, error_sensitive=None,
  gamma=None, keep_tempfiles=None, linesearch=None, max_iterations=100,
  max_linesearch=None, min_freq=None, model_filename=None,
  num_memories=None, pa_type=None, period=None, trainer_cls=None,
  variance=None, verbose='true')
import joblib
import os

OUTPUT_PATH = "models/ner/"
OUTPUT_FILE = "crf_model"

if not os.path.exists(OUTPUT_PATH):
    os.mkdir(OUTPUT_PATH)

joblib.dump(crf, os.path.join(OUTPUT_PATH, OUTPUT_FILE))
['models/ner/crf_model']

Evaluation

Let's evaluate the output of our CRF. We'll load the model from the output file above and have it predict labels for the full test set.

As a sanity check, let's take a look at its predictions for the first test sentence. This output looks pretty good: the CRF predicts all four locations in the sentence correctly. It only mislabels the person entity ('Der Kaiser') as MISC, which is a tricky case anyway, because it is a nickname rather than an actual person name.

crf = joblib.load(os.path.join(OUTPUT_PATH, OUTPUT_FILE))
y_pred = crf.predict(X_test)

example_sent = test_sents[0]

print("Sentence:", ' '.join(sent2tokens(example_sent)))
print("Predicted:", ' '.join(crf.predict([sent2features(example_sent, word2cluster)])[0]))
print("Correct:  ", ' '.join(sent2labels(example_sent)))
Sentence: Dat is in Italië , Spanje of Engeland misschien geen probleem , maar volgens ' Der Kaiser ' in Duitsland wel .
Predicted: O O O B-LOC O B-LOC O B-LOC O O O O O O O B-MISC I-MISC O O B-LOC O O
Correct:   O O O B-LOC O B-LOC O B-LOC O O O O O O O B-PER I-PER O O B-LOC O O

Now we evaluate on the full test set. We'll print out a classification report for all labels except O. If we were to include O, which far outnumbers the entity labels in our data, the average scores would be inflated artificially, simply because there's an inherently high probability that the O labels from our CRF are correct. We obtain an average F-score of 77% (micro average) across all entity types, with particularly good results for B-LOC and B-PER.

labels = list(crf.classes_)
labels.remove("O")
y_pred = crf.predict(X_test)
sorted_labels = sorted(
    labels,
    key=lambda name: (name[1:], name[0])
)

print(metrics.flat_classification_report(y_test, y_pred, labels=sorted_labels))
              precision    recall  f1-score   support

       B-LOC       0.83      0.83      0.83       774
       I-LOC       0.29      0.41      0.34        49
      B-MISC       0.84      0.61      0.71      1187
      I-MISC       0.59      0.42      0.49       410
       B-ORG       0.80      0.69      0.74       882
       I-ORG       0.74      0.66      0.70       551
       B-PER       0.80      0.90      0.85      1098
       I-PER       0.87      0.95      0.91       807

   micro avg       0.80      0.74      0.77      5758
   macro avg       0.72      0.68      0.70      5758
weighted avg       0.80      0.74      0.76      5758

Now we can also look at the most likely transitions the CRF has identified, and at the top features for every label. We'll do this with the eli5 library, which helps us explain the predictions of machine learning models.

The top transitions are quite intuitive: the most likely transitions are those within the same entity type (from a B-label to the corresponding I-label), and those where a B-label follows an O-label.

The features, too, make sense. For example, if a word does not start with an uppercase letter, it is unlikely to be an entity. By contrast, a word is very likely to be a location if it ends in ië, which is indeed a very common suffix for locations in Dutch. Notice also how informative the embedding clusters are: for all entity types, the word clusters form some of the most informative features for the CRF.

import eli5

eli5.show_weights(crf, top=30)
From \ To      O   B-LOC  I-LOC  B-MISC  I-MISC  B-ORG  I-ORG   B-PER  I-PER
O          4.141   4.583  0.0     4.141   0.0    4.366  0.0     3.819  0.0
B-LOC     -0.248  -0.279  7.101   0.0     0.0    0.0    0.0    -0.661  0.0
I-LOC     -1.062  -0.235  5.967   0.0     0.0    0.0    0.0     0.0    0.0
B-MISC    -0.985   0.655  0.0    -0.316   7.73   0.551  0.0     0.46   0.0
I-MISC    -1.781   0.0    0.0    -0.382   7.769  1.145  0.0    -0.719  0.0
B-ORG     -0.261   0.0    0.0    -0.809   0.0    0.0    7.803   0.106  0.0
I-ORG     -0.794   0.0    0.0     0.0     0.0    0.0    7.174   0.084  0.0
B-PER      0.31   -0.346  0.0    -0.611   0.0    0.0    0.0    -1.408  8.68
I-PER      0.104   0.0    0.0     0.0     0.0    0.0    0.0     0.0    6.804

y=O top features
Weight   Feature
+3.692 word.istitle=False
+3.312 word.isupper=False
+2.211 -1:word.lower="
+1.999 -1:word.lower=+
+1.835 EOS
+1.820 BOS
+1.740 word.cluster=158
+1.684 word.cluster=415
+1.590 word[-2:]=ag
+1.588 word.cluster=195
+1.556 word[-3:]=dag
+1.520 word.cluster=185
+1.520 word.cluster=178
+1.481 word.cluster=24
+1.477 word.cluster=177
+1.475 word.lower=u
… 108173 more positive …
… 6671 more negative …
-1.515 word.cluster=375
-1.538 -1:word.lower=ronde
-1.561 +1:word.lower=(
-1.618 0
-1.693 -1:word.lower=de
-1.740 postag=N
-1.808 word[-2:]='s
-1.843 word[-3:]=ple
-1.853 postag=Misc
-1.914 word.cluster=6
-1.945 +2:word.lower=1
-1.990 word.cluster=111
-2.068 word.isupper=True
-2.447 word.istitle=True

y=B-LOC top features
Weight   Feature
+3.734 word.cluster=325
+3.650 word.cluster=375
+3.466 word.cluster=68
+3.169 word.cluster=139
+2.020 -1:word.lower=in
+1.995 word.cluster=143
+1.973 word.cluster=476
+1.828 word.cluster=102
+1.617 word[-2:]=ië
+1.538 -1:word.lower=(
+1.516 -1:word.lower=uit
+1.495 word.cluster=154
+1.441 +1:word.lower=/
+1.417 +2:word.lower=-
+1.404 +2:word.lower=tel.
+1.379 -1:word.lower=het
+1.374 word.cluster=440
+1.369 word.lower=sint-michiels
+1.337 word[-3:]=urs
+1.320 word.lower=futuroscope
+1.301 word[-3:]=ope
+1.301 word.cluster=363
+1.278 word.lower=vst
+1.278 word[-3:]=VSt
+1.257 word[-3:]=den
+1.245 word.cluster=146
+1.244 word.cluster=116
+1.212 word[-2:]=St
+1.198 word.cluster=492
… 4441 more positive …
… 668 more negative …
-1.322 -2:word.lower=ronde

y=I-LOC top features
Weight   Feature
+1.778 word.cluster=238
+1.488 +2:word.lower=m
+1.367 word.cluster=161
+0.996 -1:word.lower=col
+0.977 word[-2:]=rk
+0.852 -1:word.istitle=False
+0.821 word[-2:]=al
+0.810 word.lower=york
+0.810 word[-3:]=ork
+0.809 word[-3:]=eum
+0.781 -1:word.lower=san
+0.703 word.lower=abeba
+0.703 -1:word.lower=addis
+0.702 word[-3:]=eba
+0.701 -1:word.isupper=True
+0.683 -2:word.lower='
+0.682 +2:word.lower=,
+0.677 word[-3:]=den
+0.676 word[-3:]=aan
+0.674 word[-2:]=um
+0.670 word.lower=staten
+0.669 word.cluster=38
+0.661 postag=Prep
+0.650 -2:word.lower=col
+0.650 -1:postag=Art
+0.648 -1:word.lower=museum
+0.645 word[-2:]=ba
+0.643 word.lower=hotel
… 1294 more positive …
… 131 more negative …
-0.674 EOS
-0.781 postag=Adj

y=B-MISC top features
Weight   Feature
+3.318 word.cluster=23
+2.494 word.cluster=100
+2.419 word.cluster=39
+2.097 +2:word.lower=1
+2.039 word.cluster=294
+2.036 word.cluster=338
+1.786 word[-2:]='s
+1.772 word.cluster=11
+1.700 word.lower=sport
+1.646 word.lower=buitenland
+1.628 word[-2:]=se
+1.618 word.cluster=222
+1.535 word.cluster=427
+1.505 word.lower=journaal
+1.478 word.cluster=40
+1.417 word[-3:]=nse
+1.411 postag=Adj
+1.374 word.cluster=111
+1.350 word.lower=aula
+1.347 word.lower=tobin-taks
+1.329 word.cluster=218
+1.301 word.cluster=128
+1.274 word[-3:]=ula
+1.223 BOS
+1.213 word.lower=tobin-heffing
+1.204 word.lower=ronde
+1.204 word[-2:]=ks
+1.197 word.lower=v-plan
+1.184 word[-2:]=ex
… 5588 more positive …
… 853 more negative …
-1.256 -1:word.isupper=True

y=I-MISC top features
Weight   Feature
+1.792 -2:word.lower=ronde
+1.684 -1:word.isupper=True
+1.567 -1:word.lower=ronde
+1.332 word.cluster=37
+1.323 +1:word.lower=ned
+1.316 word.cluster=325
+1.298 word.cluster=1
+1.274 word.lower=leven
+1.215 -1:word.istitle=True
+1.201 -1:postag=Num
+1.182 -1:word.lower=euro
+1.161 word.cluster=86
+1.147 word.lower=bijsluiter
+1.103 -1:word.lower=financiële
+1.100 +2:word.lower=ned
+1.096 word[-2:]=ap
+1.068 postag=Num
+1.042 word.cluster=413
+1.029 word[-2:]=00
+1.012 -1:word.lower=de
+1.010 -1:postag=Prep
+0.979 word.cluster=279
+0.954 -1:word.lower=travel
+0.950 -1:word.lower=formule
+0.947 word.lower=financiële
+0.945 -1:word.lower=brussel
+0.916 +1:word.lower='
… 3172 more positive …
… 465 more negative …
-1.115 postag=V
-1.467 +1:postag=Num
-2.953 -1:postag=Adj

y=B-ORG top features
Weight   Feature
+2.798 word.cluster=424
+2.635 word.cluster=228
+2.121 word[-3:]=com
+1.991 word.cluster=187
+1.974 word.lower=quizpeople
+1.922 word[-3:]=ple
+1.848 +1:word.lower=morgen
+1.798 word.cluster=250
+1.683 word.cluster=83
+1.560 word.cluster=29
+1.538 word[-2:]=ga
+1.525 BOS
+1.500 word[-3:]=bel
+1.476 -2:word.lower=minister
+1.422 word.cluster=207
+1.418 word.cluster=249
+1.412 word.lower=waterleau
+1.411 word.isupper=True
+1.384 word.lower=belga
+1.384 word[-3:]=lga
+1.328 word[-3:]=our
+1.320 word[-2:]=om
+1.319 word[-2:]=to
+1.299 word.lower=freenet
+1.277 word.lower=baan
+1.240 -1:word.lower=bij
+1.227 word[-3:]=lux
+1.227 word[-2:]=co
+1.222 word.cluster=21
… 3591 more positive …
… 469 more negative …
-1.156 word.isupper=False

y=I-ORG top features
Weight   Feature
+1.338 word.lower=morgen
+1.304 word.cluster=403
+1.243 word.cluster=413
+1.200 -1:word.lower=vlaams
+1.141 word.cluster=187
+1.120 word[-3:]=gen
+1.101 word[-3:]=ion
+1.057 -1:word.lower=radio
+1.028 word.cluster=321
+0.970 -1:word.lower=ned
+0.849 word[-2:]=nt
+0.841 word.cluster=143
+0.805 -2:word.lower=voor
+0.767 word.cluster=375
+0.767 -1:postag=Misc
+0.763 word[-2:]=es
+0.760 word[-2:]=3
+0.760 word.lower=3
+0.760 word[-3:]=3
+0.745 word.cluster=478
+0.733 word.cluster=411
+0.719 word[-2:]=ey
+0.705 word[-3:]=ola
+0.702 word[-2:]=le
… 2209 more positive …
… 250 more negative …
-0.777 -2:postag=V
-0.793 word[-2:]=in
-0.824 -1:word.lower=financiële
-0.842 -1:postag=Num
-1.090 0
-1.162 postag=V

y=B-PER top features
Weight   Feature
+3.523 word.cluster=489
+2.888 word.cluster=204
+2.818 word.cluster=301
+2.804 word.cluster=3
+2.765 word.cluster=246
+2.444 word.cluster=337
+2.419 word.cluster=6
+2.361 word.cluster=326
+2.199 word.cluster=296
+2.069 word.cluster=87
+1.975 word.cluster=349
+1.727 word.cluster=12
+1.575 word.cluster=105
+1.569 word.cluster=350
+1.448 word.lower=bode
+1.439 word.cluster=85
+1.347 word[-2:]=ov
+1.340 +1:word.lower=(
+1.333 -1:word.lower=-
+1.320 word.cluster=111
+1.307 +2:word.lower=--
+1.301 word[-2:]=ff
+1.300 -1:postag=V
+1.248 word[-2:]=án
+1.247 -1:word.lower=volgens
+1.231 word[-3:]=par
… 7464 more positive …
… 851 more negative …
-1.349 word[-3:]=ing
-1.360 word[-2:]=ur
-1.610 -1:word.lower=in
-1.937 -1:postag=Art

y=I-PER top features
Weight   Feature
+1.748 -1:word.lower=van
+1.425 word.cluster=3
+1.313 word.cluster=388
+1.287 word.cluster=450
+1.250 word.cluster=6
+1.231 word.cluster=249
+1.152 +1:word.lower=(
+1.127 +2:word.lower=die
+1.044 word.lower=gucht
+0.970 word.cluster=337
+0.927 -1:word.lower=pu
+0.927 word[-3:]=nfu
+0.927 word.lower=shenfu
+0.925 word[-2:]=fu
+0.921 word.cluster=296
+0.895 word.lower=grauwe
+0.875 word[-3:]=uwe
+0.869 word[-2:]=we
+0.867 -1:word.lower=de
+0.841 word[-2:]=ck
+0.815 -1:postag=Pron
+0.814 word.lower=den
… 4971 more positive …
… 548 more negative …
-0.814 word[-2:]=ma
-0.824 word[-3:]=aan
-0.851 word[-2:]=al
-0.961 word.cluster=238
-1.032 word.istitle=False
-1.055 postag=V
-1.071 word[-3:]=ter
-1.079 word[-2:]=um
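
If you prefer not to depend on eli5, sklearn-crfsuite also exposes the learned weights directly: crf.transition_features_ maps (label_from, label_to) pairs to weights, and crf.state_features_ maps (attribute, label) pairs to weights. A minimal sketch of printing the strongest ones:

from collections import Counter

# Strongest label-to-label transitions
for (label_from, label_to), weight in Counter(crf.transition_features_).most_common(10):
    print("%-8s -> %-8s %0.3f" % (label_from, label_to, weight))

# Strongest feature/label associations
for (attr, label), weight in Counter(crf.state_features_).most_common(10):
    print("%-8s %-30s %0.3f" % (label, attr, weight))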

Finding the optimal hyperparameters

So far we've trained a model with the default parameters. It's unlikely that these will give us the best performance possible. Therefore we're going to search automatically for the best hyperparameter settings by iteratively training different models and evaluating them. Eventually we'll pick the best one.

Here we'll focus on two parameters: c1 and c2. These are the parameters for L1 and L2 regularization, respectively. Regularization prevents overfitting on the training data by adding a penalty to the loss function. In L1 regularization, this penalty is the sum of the absolute values of the weights; in L2 regularization, it is the sum of the squared weights. L1 regularization performs a type of feature selection, as it assigns 0 weight to irrelevant features. L2 regularization, by contrast, makes the weight of irrelevant features small, but not necessarily zero. L1 regularization is often called the Lasso method, L2 is called the Ridge method, and the linear combination of both is called Elastic Net regularization.
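
To make these penalties concrete, here is a toy numpy illustration of how they are computed for a small weight vector; this is only meant to clarify the formulas, not code taken from CRFsuite itself:

import numpy as np

weights = np.array([0.5, -1.2, 0.0, 3.0])
c1, c2 = 0.1, 0.05

l1_penalty = c1 * np.sum(np.abs(weights))  # Lasso: sum of absolute weights
l2_penalty = c2 * np.sum(weights ** 2)     # Ridge: sum of squared weights

# The regularized training objective is then roughly:
#   loss = negative log-likelihood + l1_penalty + l2_penalty
print(l1_penalty, l2_penalty)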

We define the parameter space for c1 and c2 and use the flat F1-score to compare the individual models. We'll rely on three-fold cross validation to score each of the 50 candidates. We use a randomized search, which means we're not going to try out all specified parameter settings, but instead, we'll let the process sample randomly from the distributions we've specified in the parameter space. It will do this 50 (n_iter) times. This process takes a while, but it's worth the wait.

import scipy
from sklearn.metrics import make_scorer
from sklearn.model_selection import RandomizedSearchCV

crf = crfsuite.CRF(
    algorithm='lbfgs',
    max_iterations=100,
    all_possible_transitions=True
)

params_space = {
    'c1': scipy.stats.expon(scale=0.5),
    'c2': scipy.stats.expon(scale=0.05),
}

f1_scorer = make_scorer(metrics.flat_f1_score,
                        average='weighted', labels=labels)

rs = RandomizedSearchCV(crf, params_space,
                        cv=3,
                        verbose=1,
                        n_jobs=-1,
                        n_iter=50,
                        scoring=f1_scorer)
rs.fit(X_train, y_train)
Fitting 3 folds for each of 50 candidates, totalling 150 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done  18 tasks      | elapsed:  3.2min
[Parallel(n_jobs=-1)]: Done 150 out of 150 | elapsed: 24.3min finished
RandomizedSearchCV(cv=3, error_score='raise-deprecating',
          estimator=CRF(algorithm='lbfgs', all_possible_states=None,
  all_possible_transitions=True, averaging=None, c=None, c1=None, c2=None,
  calibration_candidates=None, calibration_eta=None,
  calibration_max_trials=None, calibration_rate=None,
  calibration_samples=None, delta=None, epsilon=None, error...e,
  num_memories=None, pa_type=None, period=None, trainer_cls=None,
  variance=None, verbose=False),
          fit_params=None, iid='warn', n_iter=50, n_jobs=-1,
          param_distributions={'c1': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f9947f04e10>, 'c2': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7f9947f04c88>},
          pre_dispatch='2*n_jobs', random_state=None, refit=True,
          return_train_score='warn',
          scoring=make_scorer(flat_f1_score, average=weighted, labels=['B-ORG', 'B-MISC', 'B-PER', 'I-PER', 'B-LOC', 'I-MISC', 'I-ORG', 'I-LOC']),
          verbose=1)

Let's take a look at the best hyperparameter settings. Our random search suggests a combination of L1 and L2 regularization.

print('best params:', rs.best_params_)
print('best CV score:', rs.best_score_)
print('model size: {:0.2f}M'.format(rs.best_estimator_.size_ / 1000000))
best params: {'c1': 0.08869645933566639, 'c2': 0.005642379370340676}
best CV score: 0.7608794798691931
model size: 1.06M

To find out what precision, recall and F1-score this translates to, we take the best estimator from our random search and evaluate it on the test set. This indeed shows a nice improvement over our initial model: we've gone from an average F1-score of 77% to 79.1%. Both precision and recall have improved, and the scores go up for all four entity types.

best_crf = rs.best_estimator_
y_pred = best_crf.predict(X_test)
print(metrics.flat_classification_report(
    y_test, y_pred, labels=sorted_labels, digits=3
))
              precision    recall  f1-score   support

       B-LOC      0.849     0.863     0.856       774
       I-LOC      0.359     0.571     0.441        49
      B-MISC      0.847     0.622     0.717      1187
      I-MISC      0.664     0.415     0.511       410
       B-ORG      0.806     0.727     0.764       882
       I-ORG      0.772     0.677     0.721       551
       B-PER      0.834     0.903     0.867      1098
       I-PER      0.892     0.958     0.924       807

   micro avg      0.823     0.761     0.791      5758
   macro avg      0.753     0.717     0.725      5758
weighted avg      0.820     0.761     0.784      5758

Conclusions

Conditional Random Fields have lost some of their popularity since the advent of neural-network models. Still, they can be very effective for named entity recognition, particularly when word embedding information is taken into account.
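
As a final illustration, here is roughly how the tuned model could be applied to a new sentence. The tokens and part-of-speech tags below are hard-coded for the example; in a real application you would need a Dutch tokenizer and POS tagger that produces the same tagset as CoNLL-2002.

# Minimal sketch: tag a new, already POS-tagged sentence (tags are illustrative).
# word2features() only looks at the token and its POS tag, so (token, tag) pairs suffice.
new_sent = [('Jan', 'N'), ('woont', 'V'), ('in', 'Prep'), ('Antwerpen', 'N'), ('.', 'Punc')]

predicted = best_crf.predict([sent2features(new_sent, word2cluster)])[0]
for (token, postag), label in zip(new_sent, predicted):
    print(token, label)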