2023-11-03 2023-11-03

Hugging Face TransformersのHfArgumentParserとargparse.ArgumentParserを併用する

個人的に少し困ったので、同様の悩みを抱える人の助けになればと思い、共有します。

状況

  • Transformersを使わないコードを書いていたのに途中で使うこととなり、Trainerを使いたくなった。
  • 既にargparse.ArgumentParserを使ってコマンドライン引数を受け取るようにしていた。
  • Trainerに渡す引数はTrainingArgumentsで管理し、わざわざargparse.ArgumentParserに引数を追加するようなことはしたくない。

対処方法

下記のように、argparse.ArgumentParserを継承したCustomArgumentParserを定義し、parse_args()内でargs, extras = self.parser.parse_known_args()を呼ぶ所がキモです。

extrasself.parserに定義されていない引数のリストが格納されているので、これをself.hf_parserに渡せば万事解決となります。

import argparse
from transformers import HfArgumentParser, TrainingArguments

class CustomArgumentParser(argparse.ArgumentParser):
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.hf_parser = HfArgumentParser(TrainingArguments)

        # Define any custom arguments using argparse
        self.parser.add_argument(
            "--dataset_path",
            type=str,
            required=True,
            help="Path to the dataset."
        )
        self.parser.add_argument(
            "--tokenizer_name_or_path", 
            type=str, 
            required=True,
            help="Path to the tokenizer."
        )
        self.parser.add_argument(
            "--model_name_or_path", 
            type=str, 
            required=True,
            help="Path to the model."
        )
        self.parser.add_argument(
            "--cache_dir", 
            type=str, 
            default=None,
            help="Path to the cache directory."
        )

    def parse_args(self):
        args, extras = self.parser.parse_known_args()
        training_args = self.hf_parser.parse_args_into_dataclasses(extras)[0]
        return args, training_args