2023-11-03
2023-11-03
Hugging Face TransformersのHfArgumentParserとargparse.ArgumentParserを併用する
個人的に少し困ったので、同様の悩みを抱える人の助けになればと思い、共有します。
状況
- Transformersを使わないコードを書いていたのに途中で使うこととなり、Trainerを使いたくなった。
- 既に
argparse.ArgumentParser
を使ってコマンドライン引数を受け取るようにしていた。 - Trainerに渡す引数は
TrainingArguments
で管理し、わざわざargparse.ArgumentParser
に引数を追加するようなことはしたくない。
対処方法
下記のように、argparse.ArgumentParser
を継承したCustomArgumentParser
を定義し、parse_args()
内でargs, extras = self.parser.parse_known_args()
を呼ぶ所がキモです。
extras
にself.parser
に定義されていない引数のリストが格納されているので、これをself.hf_parser
に渡せば万事解決となります。
import argparse
from transformers import HfArgumentParser, TrainingArguments
class CustomArgumentParser(argparse.ArgumentParser):
def __init__(self):
self.parser = argparse.ArgumentParser()
self.hf_parser = HfArgumentParser(TrainingArguments)
# Define any custom arguments using argparse
self.parser.add_argument(
"--dataset_path",
type=str,
required=True,
help="Path to the dataset."
)
self.parser.add_argument(
"--tokenizer_name_or_path",
type=str,
required=True,
help="Path to the tokenizer."
)
self.parser.add_argument(
"--model_name_or_path",
type=str,
required=True,
help="Path to the model."
)
self.parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="Path to the cache directory."
)
def parse_args(self):
args, extras = self.parser.parse_known_args()
training_args = self.hf_parser.parse_args_into_dataclasses(extras)[0]
return args, training_args