2024-06-20

非SentencePieceベースのBPEトークナイザを語彙拡張する

はじめに

LLaMA2の登場以後、特定言語向けの言語モデル構築のため、語彙拡張を伴う英語モデルのターゲット言語のデータによる追加学習が盛んに行われています。

語彙拡張の目的はフラグメンテーション1の改善にあり、これにより推論効率の改善が見込まれます。

語彙拡張のアイディア自体はシンプルですが、トークナイザの実装方法によってどのように拡張するかは変わってきます。本稿では、非SentencePieceベースのBPEトークナイザを語彙拡張する手順について共有します。(SentencePieceベースのBPEトークナイザを語彙拡張する方法については、SentencePieceのレポジトリ内に説明があります。)

前提

SentencePiece(LLaMA2やMistralなどが使用)と非SentencePieceベース(LLaMA3やOLMoなど)のBPEトークナイザの主な違いとして、byte-levelか否かが挙げられます。

SentencePiece系のBPEはbyte-fallbackオプションが適用されており、UNKトークン(=トークナイズできない)の発生を防いでいます。

他方、近年の非SentencePieceベースのBPEはbyte-level BPEとなっており、入力をUTF-8でエンコードされたバイト列に変換してからトークナイズを行っています。バイト列に対してトークナイズを行うので、UNKトークンは発生しません。2

したがって、同じアルゴリズムであっても、文字列の前処理・後処理が若干異なります。この違いは、実際にtransformersトークナイザのメタデータのうち、pre_tokenizerdecoder部分からも確認できます。

import json
from transformers import AutoTokenizer

# LLaMA2
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
print(tokenizer_json["pre_tokenizer"])
# None
print(tokenizer_json["decoder"])
# {'type': 'Sequence', 'decoders': [{'type': 'Replace', 'pattern': {'String': '▁'}, 'content': ' '}, {'type': 'ByteFallback'}, {'type': 'Fuse'}, {'type': 'Strip', 'content': ' ', 'start': 1, 'stop': 0}]}

# LLaMA3
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
print(tokenizer_json["pre_tokenizer"])
# {'type': 'Sequence', 'pretokenizers': [{'type': 'Split', 'pattern': {'Regex': "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"}, 'behavior': 'Isolated', 'invert': False}, {'type': 'ByteLevel', 'add_prefix_space': False, 'trim_offsets': True, 'use_regex': False}]}
print(tokenizer_json["decoder"])
# {'type': 'ByteLevel', 'add_prefix_space': True, 'trim_offsets': True, 'use_regex': True}

実装

はじめに、拡張元のトークナイザtokenizerとターゲットとなる言語で学習された補助のトークナイザaux_tokenizerを用意しておきます。

ここでは、比較的語彙拡張の効果が見られやすいギリシャ語を例に取ります。

1. 読み込み

拡張元のトークナイザとターゲット言語の補助トークナイザを読み込み、マージルールや語彙の辞書を取得します。

以下の例では、LLaMA3を拡張元のトークナイザとしています。 補助のトークナイザは、語彙サイズを5万トークンとし、ギリシャ語のCC-100コーパスから無作為に$2^{20}$文を抽出したデータセットで学習させたものです。それ以外の学習設定はLLaMA3に準じています。3

import json
import copy

from transformers import AutoTokenizer
from tokenizers.models import BPE

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
vocab = tokenizer.get_vocab()
tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
merges = tokenizer_json["model"]["merges"]

aux_tokenizer = AutoTokenizer.from_pretrained("atsuki-yamaguchi/cc100-el-50k")
aux_tokenizer_json = json.loads(aux_tokenizer._tokenizer.to_str())
aux_merges = aux_tokenizer_json["model"]["merges"]

2. 語彙とマージルールの追加

拡張元のトークナイザの語彙の辞書とマージルールのリストに、語彙と対応するマージルールを追加します。ここでは、1万個の新規語彙を最大値として追加します。4

# merge the tokenizers
num_new_token = 0
max_new_token = 10000
ret_vocab = copy.copy(vocab)
ret_merges = []
old_merges = copy.copy(merges)
for merge in aux_merges:
    # vocab
    token_1, token_2 = merge.split(" ")
    token = token_1 + token_2
    if num_new_token < max_new_token:
        if token_1 not in ret_vocab and token_2 not in ret_vocab: # both are new
            ret_vocab[token_1] = len(vocab) + num_new_token
            ret_vocab[token_2] = len(vocab) + num_new_token + 1
            num_new_token += 2
        elif token_1 not in ret_vocab and token_2 in ret_vocab: # new + existing
            ret_vocab[token_1] = len(vocab) + num_new_token
            num_new_token += 1
        elif token_1 in ret_vocab and token_2 not in ret_vocab: # old + existing
            ret_vocab[token_2] = len(vocab) + num_new_token
            num_new_token += 1
        else: # both are existing tokens
            pass
        if token not in ret_vocab:
            ret_vocab[token] = len(vocab) + num_new_token
            num_new_token += 1
    # merge
    if merge in merges:
        old_merges.remove(merge)
        ret_merges.append(merge)
    elif token in ret_vocab and token_1 in ret_vocab and token_2 in ret_vocab:
        ret_merges.append(merge)

3. トークナイザの再学習

拡張した語彙とマージルールを基に、BPEトークナイザのインスタンスを作成し、上書きします。

# retrain tokenizer
merges = ret_merges + old_merges
vocab = ret_vocab
tokenizer.backend_tokenizer.model = BPE(
    vocab=vocab,
    merges=[(merge.split(' ')[0], merge.split(' ')[1]) for merge in merges],
    fuse_unk=False,
)

4. 保存

最後に上書きしたトークナイザを保存して完了です。

# save
tokenizer.save_pretrained("/path/to/output/dir")

効果

下記の例文が語彙拡張前後のトークナイザで何トークンになるかを計測します。

Μου είπαν ότι, θα έπρεπε να καλέσω έναν άντρα στο τέλος για να συναντηθούμε. Ερώτηση: Ο τύπος εμφανίστηκε λίγο αργά. Αληθές, Ψευδές, ή Κανένα από τα δύο; Απάντηση: Κανένα από τα δύο

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
modified_tokenizer = AutoTokenizer.from_pretrained("/path/to/output/dir")

text = "Μου είπαν ότι, θα έπρεπε να καλέσω έναν άντρα στο τέλος για να συναντηθούμε. Ερώτηση: Ο τύπος εμφανίστηκε λίγο αργά. Αληθές, Ψευδές, ή Κανένα από τα δύο; Απάντηση: Κανένα από τα δύο"

print(len(tokenizer.encode(text)))
# 81

print(len(modified_tokenizer.encode(text)))
# 46

結果として1万トークンをLLaMA3トークナイザに新規に追加することで、35トークン減少しました。

おわりに

SentencePieceベースのBPEトークナイザの語彙拡張の例はたくさん見かけますが、それ以外の実践例をあまり見かけたことがなかったので今回の記事作成に至りました。 何かの役に立てば幸いです。

  1. ほとんどの大規模言語モデルは英語データを中心に学習されているため、サポート外の言語を使用する際にトークン数が増加してしまう事象が報告されている。詳しくは、Ahia et al. (2023)などを参照されたい。 

  2. 詳しくは、minbpeを参照されたい。 

  3. 学習にはtokenizer.train_new_from_iterator()を使うのが便利です。 

  4. マージルールはトークンの出現頻度順にソートされているため、リストを順繰りに処理することによりトークンの頻度順に語彙を追加できます。詳しくは、transformersのissue等を確認してください。