import sys

from datasets import load_dataset
from span_marker import SpanMarkerModel, Trainer
from transformers import TrainingArguments
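# SpanMarkerModel and Trainer come from the SpanMarker NER library
# (pip install span_marker).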

# Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
dataset = load_dataset("gwlms/germeval2014")
labels = dataset["train"].features["ner_tags"].feature.names
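# `labels` is the list of tag strings defined by the dataset schema,
# e.g. "O", "B-PER", "I-PER", ...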

# Initialize a SpanMarker model using a pretrained BERT-style encoder
model_name = sys.argv[1]
model = SpanMarkerModel.from_pretrained(
    model_name,
    labels=labels,
    # SpanMarker hyperparameters:
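    # model_max_length: truncate each tokenized input to at most this many tokens
    # marker_max_length: cap on how many entity markers are packed into one sample
    # entity_max_length: candidate spans longer than this many words are skipped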
    model_max_length=256,
    marker_max_length=128,
    entity_max_length=8,
)

args = TrainingArguments(
    output_dir="/tmp",
    per_device_eval_batch_size=64,
)
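# Unset training hyperparameters (learning rate, epochs, train batch size, ...)
# fall back to the transformers defaults.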

# Initialize the trainer using our model, training args & dataset, and train
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)
trainer.train()
| print("Evaluating on development set...") | |
| dev_metrics = trainer.evaluate(dataset["validation"], metric_key_prefix="eval") | |
| print(dev_metrics) | |
| print("Evaluating on test set...") | |
| test_metrics = trainer.evaluate(dataset["test"], metric_key_prefix="test") | |
| print(test_metrics) | |
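
# Example invocation (the script name and encoder checkpoint are illustrative):
#   python train_germeval_ner.py bert-base-german-cased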