Spaces:
Runtime error
Christian Rene Thelen
committed on
Commit · 963cb02
1 Parent(s): 21f86d4
Initial Commit
Browse files
- .gitignore +222 -0
- .python-version +1 -0
- Dockerfile +79 -0
- README.md +50 -0
- docker-compose.yml +46 -0
- requirements.txt +263 -0
- setup_env.sh +21 -0
- subtask2_final_gradio.py +618 -0
- subtask_1/exp019-4.py +224 -0
- subtask_1/grid_cv_results.exp019-2.csv +16 -0
- subtask_1/grid_cv_results.exp019-3.csv +26 -0
- subtask_1/grid_cv_results.exp019-4.csv +49 -0
- subtask_1/submission_subtask1-2.ipynb +608 -0
- subtask_1/submission_subtask1.ipynb +719 -0
- subtask_2/exp027-1.py +736 -0
- subtask_2/exp027-2.py +736 -0
- subtask_2/exp027-2_retraining.py +736 -0
- subtask_2/submission_subtask2-2.ipynb +0 -0
- subtask_2/submission_subtask2.ipynb +0 -0
.gitignore
ADDED
@@ -0,0 +1,222 @@
### VisualStudioCode template
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### JupyterNotebooks template
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/

.ipynb_checkpoints
*/.ipynb_checkpoints/*

# IPython
profile_default/
ipython_config.py

# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/

### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

*.pkl
*.zip
share-GermEval2025-data
opt/
experiments/*/trainer
*.pth
experiments/*/wandb
experiments/*/trainer*/
experiments/*/models_*/
*.npy
experiments/exp*/*_exp_fold?/
experiments/exp*/exp*-fold-?/
experiments/exp???/wandb_logs/

experiments/exp007-large/fp16/
experiments/exp008/exp008-2/
experiments/exp008/exp008-3/
experiments/exp008/exp008-4/
experiments/exp008/exp008/
experiments/exp011/exp011-9/
experiments/exp012-fixed/exp012-fixed-2/
experiments/exp012-fixed/exp012-fixed-3/
experiments/exp012-fixed/exp012-fixed-4/
experiments/exp012-fixed/exp012-fixed-5/
experiments/exp012-fixed/exp012-fixed/
experiments/exp012/exp012-2/
experiments/exp012/exp012-3/
experiments/exp012/exp012/
experiments/exp013/exp013-2/
experiments/exp013/exp013/
.python-version
ADDED
@@ -0,0 +1 @@
3.12.9
Dockerfile
ADDED
@@ -0,0 +1,79 @@
# Use NVIDIA CUDA base image with Python 3.12.9
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHON_VERSION=3.12.9

# Install system dependencies
RUN apt-get update && apt-get install -y \
    software-properties-common \
    build-essential \
    libssl-dev \
    libffi-dev \
    libsqlite3-dev \
    libreadline-dev \
    libbz2-dev \
    libncurses5-dev \
    libncursesw5-dev \
    xz-utils \
    tk-dev \
    libxml2-dev \
    libxmlsec1-dev \
    libgdbm-dev \
    liblzma-dev \
    git \
    wget \
    curl \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# Install Python 3.12.9 from source
RUN cd /tmp && \
    wget https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \
    tar xzf Python-${PYTHON_VERSION}.tgz && \
    cd Python-${PYTHON_VERSION} && \
    ./configure --enable-optimizations --with-ensurepip=install && \
    make -j $(nproc) && \
    make altinstall && \
    cd / && \
    rm -rf /tmp/Python-${PYTHON_VERSION}*

# Create symlinks for python3.12
RUN ln -sf /usr/local/bin/python3.12 /usr/bin/python3
RUN ln -sf /usr/local/bin/python3.12 /usr/bin/python
RUN ln -sf /usr/local/bin/pip3.12 /usr/bin/pip

# Upgrade pip
RUN python3 -m pip install --upgrade pip

# Set work directory
WORKDIR /app

# Copy requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY subtask2_final_gradio.py .

# Create directory for model weights
RUN mkdir -p experiments/exp027

# Set CUDA environment variables
ENV CUDA_DEVICE_ORDER=PCI_BUS_ID
ENV CUDA_VISIBLE_DEVICES=0

# Expose port
EXPOSE 7860

# Create non-root user for security
RUN useradd -m -u 1002 appuser && chown -R appuser:appuser /app
USER appuser

# Command to run the application
CMD ["python", "subtask2_final_gradio.py"]
README.md
ADDED
@@ -0,0 +1,50 @@
# AIxcellent Vibes at GermEval 2025 Shared Task on Candy Speech Detection

## Results
| Subtask | Submission | Model              | (strict) F1 Score |
|---------|------------|--------------------|------------------:|
| 1       | 1          | Qwen3-Embedding-8B |             0.875 |
| 1       | 2          | XLM-RoBERTa-Large  |             0.891 |
| 2       | 1          | GBERT-Large        |             0.623 |
| 2       | 2          | XLM-RoBERTa-Large  |             0.631 |


## Setup

```bash
python_version="$(cat .python-version)"

# install the interpreter if it's missing
pyenv install -s "${python_version}"

# select python version for current shell
pyenv shell "${python_version}"

# create venv if missing
if [[ ! -d venv ]]; then
    python -m venv venv
fi

# activate venv & install packages
source venv/bin/activate

pip install -U pip setuptools wheel
pip install -r requirements.txt
```


This repository contains the code used to run the experiments for the Bachelor thesis **Flauschdetektion (GermEval 2025)**
in the degree program Angewandte Mathematik und Informatik (dual) B.Sc. at FH Aachen.


---


**Degree program**

Angewandte Mathematik und Informatik B.Sc. ([AMI](https://www.fh-aachen.de/studium/angewandte-mathematik-und-informatik-bsc)) at [FH Aachen](https://www.fh-aachen.de/), University of Applied Sciences.

**Vocational training with IHK qualification**

Mathematical-technical software developer ([MaTSE](https://www.matse-ausbildung.de/startseite.html)) at the Teaching and Research Area of Engineering Hydrology ([LFI](https://lfi.rwth-aachen.de/)) at [RWTH Aachen](https://www.rwth-aachen.de/) University.
docker-compose.yml
ADDED
@@ -0,0 +1,46 @@
services:
  span-classifier:
    container_name: span-classifier-app
    build:
      context: .
      dockerfile: Dockerfile
      network: host
    image: 8e6b331d0418
    ports:
      - "7860:7860"
    volumes:
      # Mount model weights directory
      - ./experiments:/app/experiments:ro
      # Mount cache directory for Hugging Face models
      - /home/cthelen/.cache/huggingface:/home/appuser/.cache/huggingface
      # Mount logs directory
      - ./logs:/app/logs
    environment:
      - PYTHONUNBUFFERED=1
      - CUDA_DEVICE_ORDER=PCI_BUS_ID
      - CUDA_VISIBLE_DEVICES=0
      - GRADIO_SERVER_NAME=0.0.0.0
      - GRADIO_SERVER_PORT=7860
      - TRANSFORMERS_CACHE=/home/appuser/.cache/huggingface
      - TORCH_HOME=/home/appuser/.cache/torch
    runtime: nvidia
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    restart: unless-stopped
    labels:
      - "traefik.enable=true"
      - "traefik.http.routers.demo.rule=Host(`span-classifier.gpu2.lfi.rwth-aachen.de`)"
      - "traefik.http.routers.demo.tls=true"
      - "traefik.http.routers.demo.tls.certresolver=letsencrypt"
      - "com.centurylinklabs.watchtower.enable=false"
    networks:
      - web

networks:
  web:
    external: true
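Once the stack is running (`docker compose up -d`), the Gradio app is published on port 7860. The sketch below is a minimal smoke test from the host, assuming `gradio_client` (pinned in requirements.txt) is installed; `api_name="/predict"` is a guess at the auto-generated endpoint name, not something the repo defines, so inspect `client.view_api()` if it differs.

```python
# Minimal smoke test against the running container (hypothetical endpoint name).
from gradio_client import Client

client = Client("http://localhost:7860")  # port published in docker-compose.yml

# Inspect the actual endpoint names exposed by the Blocks app.
print(client.view_api())

# "/predict" is Gradio's conventional default; adjust to what view_api() reports.
result = client.predict(
    "Ich stimme allen zu die denken das Roman und Heiko super sind !!!!",
    api_name="/predict",
)
print(result)  # expected: (highlighted HTML, summary markdown, details markdown)
```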
requirements.txt
ADDED
@@ -0,0 +1,263 @@
absl-py==2.2.2
accelerate==1.7.0
aiofiles==24.1.0
annotated-types==0.7.0
anyio==4.9.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asgiref==3.8.1
asttokens==3.0.0
astunparse==1.6.3
async-lru==2.0.5
attrs==25.3.0
azure-ai-inference==1.0.0b9
azure-ai-ml==1.27.0
azure-common==1.1.28
azure-core==1.34.0
azure-core-tracing-opentelemetry==1.0.0b12
azure-identity==1.22.0
azure-mgmt-core==1.5.0
azure-monitor-opentelemetry==1.6.8
azure-monitor-opentelemetry-exporter==1.0.0b36
azure-storage-blob==12.25.1
azure-storage-file-datalake==12.20.0
azure-storage-file-share==12.21.0
babel==2.17.0
beautifulsoup4==4.13.4
bleach==6.2.0
blis==1.2.1
catalogue==2.0.10
certifi==2025.4.26
cffi==1.17.1
charset-normalizer==3.4.2
click==8.1.8
cloudpathlib==0.21.0
colorama==0.4.6
comm==0.2.2
confection==0.1.5
contourpy==1.3.2
cryptography==44.0.3
cupy-cuda12x==12.3.0
cycler==0.12.1
cymem==2.0.11
de_core_news_sm @ https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.8.0/de_core_news_sm-3.8.0-py3-none-any.whl#sha256=fec69fec52b1780f2d269d5af7582a5e28028738bd3190532459aeb473bfa3e7
debugpy==1.8.14
decorator==5.2.1
defusedxml==0.7.1
Deprecated==1.2.18
docker-pycreds==0.4.0
executing==2.2.0
fastapi==0.115.14
fastjsonschema==2.21.1
fastrlock==0.8.3
ffmpy==0.6.0
filelock==3.18.0
fixedint==0.1.6
flatbuffers==25.2.10
fonttools==4.57.0
fqdn==1.5.1
fsspec==2025.3.2
gast==0.6.0
gitdb==4.0.12
GitPython==3.1.44
google-pasta==0.2.0
gradio==5.35.0
gradio_client==1.10.4
groovy==0.1.2
grpcio==1.71.0
h11==0.16.0
h5py==3.13.0
hf-xet==1.1.0
httpcore==1.0.9
httpx==0.28.1
huggingface-hub==0.31.1
idna==3.10
imbalanced-learn==0.13.0
imblearn==0.0
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==8.36.0
ipython_pygments_lexers==1.1.1
isodate==0.7.2
isoduration==20.11.0
jedi==0.19.2
Jinja2==3.1.6
joblib==1.5.0
json5==0.12.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2025.4.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.4.2
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
keras==3.9.2
kiwisolver==1.4.8
langcodes==3.5.0
language_data==1.3.0
libclang==18.1.1
marisa-trie==1.2.1
Markdown==3.8
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.26.1
matplotlib==3.10.1
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.1.3
ml_dtypes==0.5.1
mpmath==1.3.0
msal==1.32.3
msal-extensions==1.3.1
msrest==0.7.1
murmurhash==1.0.12
namex==0.0.9
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
notebook==7.4.2
notebook_shim==0.2.4
numpy==1.26.4
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
oauthlib==3.2.2
opentelemetry-api==1.31.1
opentelemetry-instrumentation==0.52b1
opentelemetry-instrumentation-asgi==0.52b1
opentelemetry-instrumentation-dbapi==0.52b1
opentelemetry-instrumentation-django==0.52b1
opentelemetry-instrumentation-fastapi==0.52b1
opentelemetry-instrumentation-flask==0.52b1
opentelemetry-instrumentation-psycopg2==0.52b1
opentelemetry-instrumentation-requests==0.52b1
opentelemetry-instrumentation-urllib==0.52b1
opentelemetry-instrumentation-urllib3==0.52b1
opentelemetry-instrumentation-wsgi==0.52b1
opentelemetry-resource-detector-azure==0.1.5
opentelemetry-sdk==1.31.1
opentelemetry-semantic-conventions==0.52b1
opentelemetry-util-http==0.52b1
opt_einsum==3.4.0
optree==0.15.0
orjson==3.10.18
overrides==7.7.0
packaging==25.0
pandas==2.2.3
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.2.1
platformdirs==4.3.8
preshed==3.0.9
prometheus_client==0.21.1
prompt_toolkit==3.0.51
protobuf==5.29.4
psutil==6.1.1
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.11.4
pydantic_core==2.33.2
pydash==8.0.5
pydub==0.25.1
Pygments==2.19.1
PyJWT==2.10.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-json-logger==3.3.0
python-multipart==0.0.20
pytz==2025.2
PyYAML==6.0.2
pyzmq==26.4.0
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
requests-oauthlib==2.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==14.0.0
rpds-py==0.24.0
ruff==0.12.1
safehttpx==0.1.6
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.3
seaborn==0.13.2
semantic-version==2.10.0
Send2Trash==1.8.3
sentry-sdk==2.28.0
setproctitle==1.3.6
setuptools==80.3.1
shellingham==1.5.4
six==1.17.0
sklearn-compat==0.1.3
smart-open==7.1.0
smmap==5.0.2
sniffio==1.3.1
soupsieve==2.7
spacy==3.8.5
spacy-legacy==3.0.12
spacy-loggers==1.0.5
srsly==2.5.1
stack-data==0.6.3
starlette==0.46.2
strictyaml==1.7.3
sympy==1.14.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tensorflow==2.19.0
termcolor==3.1.0
terminado==0.18.1
tf_keras==2.19.0
thinc==8.3.4
threadpoolctl==3.6.0
tinycss2==1.4.0
tokenizers==0.21.1
tomlkit==0.13.3
torch==2.7.0
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
transformers==4.51.3
triton==3.3.0
typer==0.15.3
types-python-dateutil==2.9.0.20241206
typing-inspection==0.4.0
typing_extensions==4.13.2
tzdata==2025.2
uri-template==1.3.0
urllib3==2.4.0
uvicorn==0.35.0
wandb==0.19.11
wasabi==1.1.3
wcwidth==0.2.13
weasel==0.4.1
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
websockets==15.0.1
Werkzeug==3.1.3
wheel==0.45.1
wrapt==1.17.2
zipp==3.21.0
setup_env.sh
ADDED
@@ -0,0 +1,21 @@
#!/usr/bin/env bash
set -e

python_version="$(cat .python-version)"

# 1. Install the interpreter if it's missing
pyenv install -s "${python_version}"

# 2. Select python version for current shell
pyenv shell "${python_version}"

# 3. Create venv if missing
if [[ ! -d venv ]]; then
    python -m venv venv
fi

# 4. Activate venv & install packages
source venv/bin/activate

pip install -U pip setuptools wheel
pip install -r requirements.txt
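After the script finishes, a quick sanity check confirms the pinned interpreter and GPU stack are usable inside the venv. This check is not part of the repository; it is a generic sketch using packages pinned in requirements.txt.

```python
# Environment sanity check (run inside the activated venv).
import sys

import torch
import transformers

print(sys.version)                # should report 3.12.9 per .python-version
print(torch.__version__)          # 2.7.0 per requirements.txt
print(transformers.__version__)   # 4.51.3 per requirements.txt
print(torch.cuda.is_available())  # True if an NVIDIA runtime is visible
```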
subtask2_final_gradio.py
ADDED
@@ -0,0 +1,618 @@
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    BertForTokenClassification,
    AutoModelForTokenClassification,
    pipeline
)
import torch
import os
import seaborn as sns
from matplotlib.colors import to_hex
import html


os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

class SpanClassifierWithStrictF1:
    def __init__(self, model_name="deepset/gbert-base"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.labels = [
            "O",
            "B-positive feedback", "B-compliment", "B-affection declaration", "B-encouragement", "B-gratitude", "B-agreement", "B-ambiguous", "B-implicit", "B-group membership", "B-sympathy",
            "I-positive feedback", "I-compliment", "I-affection declaration", "I-encouragement", "I-gratitude", "I-agreement", "I-ambiguous", "I-implicit", "I-group membership", "I-sympathy"
        ]
        self.label2id = {label: i for i, label in enumerate(self.labels)}
        self.id2label = {i: label for i, label in enumerate(self.labels)}

    def create_dataset(self, comments_df, spans_df):
        """Build the dataset with BIO labels and store evaluation data."""
        examples = []
        eval_data = []  # for strict F1 computation

        spans_grouped = spans_df.groupby(['document', 'comment_id'])

        for _, row in comments_df.iterrows():
            text = row['comment']
            document = row['document']
            comment_id = row['comment_id']
            key = (document, comment_id)

            # True spans for this comment
            if key in spans_grouped.groups:
                true_spans = [(span_type, int(start), int(end))
                              for span_type, start, end in
                              spans_grouped.get_group(key)[['type', 'start', 'end']].values]
            else:
                true_spans = []

            # Tokenization
            tokenized = self.tokenizer(text, truncation=True, max_length=512,
                                       return_offsets_mapping=True)

            # Create BIO labels
            labels = self._create_bio_labels(tokenized['offset_mapping'],
                                             spans_grouped.get_group(key)[['start', 'end', 'type']].values
                                             if key in spans_grouped.groups else [])

            examples.append({
                'input_ids': tokenized['input_ids'],
                'attention_mask': tokenized['attention_mask'],
                'labels': labels
            })

            # Store evaluation data
            eval_data.append({
                'text': text,
                'offset_mapping': tokenized['offset_mapping'],
                'true_spans': true_spans,
                'document': document,
                'comment_id': comment_id
            })

        return examples, eval_data

    def _create_bio_labels(self, offset_mapping, spans):
        """Create BIO labels for the tokens."""
        labels = [0] * len(offset_mapping)  # 0 = "O"

        for start, end, type_label in spans:
            for i, (token_start, token_end) in enumerate(offset_mapping):
                if token_start is None:  # special tokens
                    continue

                # Token overlaps with the span
                if token_start < end and token_end > start:
                    if token_start <= start:
                        labels[i] = self.label2id[f'B-{type_label}']  # e.g. B-compliment
                    else:
                        labels[i] = self.label2id[f'I-{type_label}']  # e.g. I-compliment

        return labels

    def compute_metrics(self, eval_pred):
        """Compute strict F1 for the Trainer."""
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=2)

        # Convert predictions to spans
        batch_pred_spans = []
        batch_true_spans = []

        for i, (pred_seq, label_seq) in enumerate(zip(predictions, labels)):
            # Evaluation data for this example
            if i < len(self.current_eval_data):
                eval_item = self.current_eval_data[i]
                text = eval_item['text']
                offset_mapping = eval_item['offset_mapping']
                true_spans = eval_item['true_spans']

                # Keep only valid predictions (skip padding tokens)
                valid_predictions = []
                valid_offsets = []

                for j, (pred_label, true_label) in enumerate(zip(pred_seq, label_seq)):
                    if true_label != -100 and j < len(offset_mapping):
                        valid_predictions.append(pred_label)
                        valid_offsets.append(offset_mapping[j])

                # Convert to spans
                pred_spans = self._predictions_to_spans(valid_predictions, valid_offsets, text)
                pred_spans_tuples = [(span['type'], span['start'], span['end']) for span in pred_spans]

                batch_pred_spans.append(pred_spans_tuples)
                batch_true_spans.append(true_spans)

        # Compute strict F1
        strict_f1, strict_precision, strict_recall, tp, fp, fn = self._calculate_strict_f1(
            batch_true_spans, batch_pred_spans
        )

        torch.cuda.memory.empty_cache()

        return {
            "strict_f1": torch.tensor(strict_f1),
            "strict_precision": torch.tensor(strict_precision),
            "strict_recall": torch.tensor(strict_recall),
            "true_positives": torch.tensor(tp),
            "false_positives": torch.tensor(fp),
            "false_negatives": torch.tensor(fn)
        }

    def _calculate_strict_f1(self, true_spans_list, pred_spans_list):
        """Compute strict F1 across all comments."""
        tp, fp, fn = 0, 0, 0

        for true_spans, pred_spans in zip(true_spans_list, pred_spans_list):
            # Find exact matches (type and span boundaries must agree)
            matches = self._find_exact_matches(true_spans, pred_spans)

            tp += len(matches)
            fp += len(pred_spans) - len(matches)
            fn += len(true_spans) - len(matches)

        # Compute metrics
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        return f1, precision, recall, tp, fp, fn

    def _find_exact_matches(self, true_spans, pred_spans):
        """Find exact matches between true and predicted spans."""
        matches = []
        used_pred = set()

        for true_span in true_spans:
            for i, pred_span in enumerate(pred_spans):
                if i not in used_pred and true_span == pred_span:
                    matches.append((true_span, pred_span))
                    used_pred.add(i)
                    break

        return matches

    def _predictions_to_spans(self, predicted_labels, offset_mapping, text):
        """Convert token predictions to spans."""
        spans = []
        current_span = None

        for i, label_id in enumerate(predicted_labels):
            if i >= len(offset_mapping):
                break

            label = self.id2label[label_id]
            token_start, token_end = offset_mapping[i]

            if token_start is None:
                continue

            if label.startswith('B-'):
                if current_span:
                    spans.append(current_span)
                current_span = {
                    'type': label[2:],
                    'start': token_start,
                    'end': token_end,
                    'text': text[token_start:token_end]
                }
            elif label.startswith('I-') and current_span:
                current_span['end'] = token_end
                current_span['text'] = text[current_span['start']:current_span['end']]
            else:
                if current_span:
                    spans.append(current_span)
                    current_span = None

        if current_span:
            spans.append(current_span)

        return spans

    def predict(self, texts):
        """Predict spans for new texts."""
        if not hasattr(self, 'model'):
            raise ValueError("Modell muss erst trainiert werden!")

        predictions = []
        device = next(self.model.parameters()).device

        for text in texts:
            # Tokenization
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True,
                                    max_length=512, return_offsets_mapping=True)

            offset_mapping = inputs.pop('offset_mapping')
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Prediction
            with torch.no_grad():
                outputs = self.model(**inputs)

            predicted_labels = torch.argmax(outputs.logits, dim=2)[0].cpu().numpy()

            # Extract spans
            spans = self._predictions_to_spans(predicted_labels, offset_mapping[0], text)
            predictions.append({'text': text, 'spans': spans})

        return predictions

    def evaluate_strict_f1(self, comments_df, spans_df):
        """Evaluate strict F1 on test data."""
        if not hasattr(self, 'model'):
            raise ValueError("Modell muss erst trainiert werden!")

        print("Evaluiere Strict F1...")

        # Predictions for all comments
        texts = comments_df['comment'].tolist()
        predictions = self.predict(texts)

        # Organize true spans
        spans_grouped = spans_df.groupby(['document', 'comment_id'])
        true_spans_dict = {}
        pred_spans_dict = {}

        for i, (_, row) in enumerate(comments_df.iterrows()):
            key = (row['document'], row['comment_id'])

            # True spans
            if key in spans_grouped.groups:
                true_spans = [(span_type, int(start), int(end))
                              for span_type, start, end in
                              spans_grouped.get_group(key)[['type', 'start', 'end']].values]
            else:
                true_spans = []

            # Predicted spans
            pred_spans = [(span['type'], span['start'], span['end'])
                          for span in predictions[i]['spans']]

            true_spans_dict[key] = true_spans
            pred_spans_dict[key] = pred_spans

        # Compute strict F1
        all_true_spans = list(true_spans_dict.values())
        all_pred_spans = list(pred_spans_dict.values())

        f1, precision, recall, tp, fp, fn = self._calculate_strict_f1(all_true_spans, all_pred_spans)

        print(f"\nStrict F1 Ergebnisse:")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1-Score: {f1:.4f}")
        print(f"True Positives: {tp}, False Positives: {fp}, False Negatives: {fn}")

        return {
            'strict_f1': f1,
            'strict_precision': precision,
            'strict_recall': recall,
            'true_positives': tp,
            'false_positives': fp,
            'false_negatives': fn
        }

def convert_spans(row):
    spans = row['predicted_spans']
    document = row['document']
    comment_id = row['comment_id']
    return [{'document': document, 'comment_id': comment_id, 'type': span['type'], 'start': span['start'], 'end': span['end']} for span in spans]

def pred_to_spans(row):
    predicted_labels, offset_mapping, text = row['predicted_labels'], row['offset_mapping'], row['comment']
    return [classifier._predictions_to_spans(predicted_labels, offset_mapping, text)]


def create_highlighted_html(text, spans):
    """Build HTML with highlighted spans."""
    if not spans:
        return html.escape(text)

    # Original light palette for the span types (overridden by the pastel palette below)
    colors = {
        'positive feedback': '#FFE5E5',
        'compliment': '#E5F3FF',
        'affection declaration': '#FFE5F3',
        'encouragement': '#E5FFE5',
        'gratitude': '#FFF5E5',
        'agreement': '#F0E5FF',
        'ambiguous': '#E5E5E5',
        'implicit': '#E5FFFF',
        'group membership': '#FFFFE5',
        'sympathy': '#F5E5FF'
    }

    colors = {
        'positive feedback': '#8dd3c7',      # pastel teal
        'compliment': '#ffffb3',             # light pastel yellow
        'affection declaration': '#bebada',  # pastel lilac
        'encouragement': '#fb8072',          # pastel salmon
        'gratitude': '#80b1d3',              # pastel sky blue
        'agreement': '#fdb462',              # pastel peach
        'ambiguous': '#d9d9d9',              # neutral pastel gray
        'implicit': '#fccde5',               # pastel rosé
        'group membership': '#b3de69',       # pastel lime green
        'sympathy': '#bc80bd'                # pastel lavender
    }

    # Sort spans by start position
    sorted_spans = sorted(spans, key=lambda x: x['start'])

    html_parts = []
    last_end = 0

    for span in sorted_spans:
        # Text before the span
        if span['start'] > last_end:
            html_parts.append(html.escape(text[last_end:span['start']]))

        # Highlighted span
        color = colors.get(span['type'], '#EEEEEE')
        span_text = html.escape(text[span['start']:span['end']])
        html_parts.append(
            f'<span style="background-color: {color}; padding: 2px 4px; border-radius: 3px; margin: 1px; display: inline-block;" title="{span["type"]}">{span_text}</span>')

        last_end = span['end']

    # Remaining text
    if last_end < len(text):
        html_parts.append(html.escape(text[last_end:]))

    return ''.join(html_parts)


def create_legend():
    """Build a legend for the span types."""
    #colors = {
    #    'positive feedback': '#FFE5E5',
    #    'compliment': '#E5F3FF',
    #    'affection declaration': '#FFE5F3',
    #    'encouragement': '#E5FFE5',
    #    'gratitude': '#FFF5E5',
    #    'agreement': '#F0E5FF',
    #    'ambiguous': '#E5E5E5',
    #    'implicit': '#E5FFFF',
    #    'group membership': '#FFFFE5',
    #    'sympathy': '#F5E5FF'
    #}

    colors = {
        'positive feedback': '#8dd3c7',      # pastel teal
        'compliment': '#ffffb3',             # light pastel yellow
        'affection declaration': '#bebada',  # pastel lilac
        'encouragement': '#fb8072',          # pastel salmon
        'gratitude': '#80b1d3',              # pastel sky blue
        'agreement': '#fdb462',              # pastel peach
        'ambiguous': '#d9d9d9',              # neutral pastel gray
        'implicit': '#fccde5',               # pastel rosé
        'group membership': '#b3de69',       # pastel lime green
        'sympathy': '#bc80bd'                # pastel lavender
    }
    legend_html = "<div style='margin: 10px 0;'><h4>Candy Speech Types:</h4>"
    for span_type, color in colors.items():
        legend_html += f'<span style="background-color: {color}; padding: 4px 8px; border-radius: 3px; margin: 2px; display: inline-block;">{span_type}</span>'
    legend_html += "</div>"

    return legend_html


def analyze_text(text):
    """Analyze the text and return the results."""
    if not text.strip():
        return "Bitte geben Sie einen Text ein.", "", ""

    try:
        # Prediction with the classifier
        predictions = classifier.predict([text])
        spans = predictions[0]['spans']

        # Build HTML with highlighted spans
        highlighted_html = create_highlighted_html(text, spans)

        # Build the summary
        summary = create_summary(spans)

        # Build detailed span information
        details = create_details(spans, text)

        return highlighted_html, summary, details

    except Exception as e:
        return f"Fehler bei der Analyse: {str(e)}", "", ""


def create_summary(spans):
    """Build a summary of the detected spans."""
    if not spans:
        return "Keine Spans gefunden."

    span_counts = {}
    for span in spans:
        span_type = span['type']
        span_counts[span_type] = span_counts.get(span_type, 0) + 1

    summary_lines = [f"**Insgesamt {len(spans)} Spans gefunden:**"]
    for span_type, count in sorted(span_counts.items()):
        summary_lines.append(f"- {span_type}: {count}")

    return "\n".join(summary_lines)


def create_details(spans, text):
    """Build detailed information about the spans."""
    if not spans:
        return "Keine Details verfügbar."

    details_lines = ["**Span-Informationen:**"]
    for i, span in enumerate(spans, 1):
        span_text = text[span['start']:span['end']]
        details_lines.append(f"{i}. **{span['type']}** ({span['start']}-{span['end']}): \"{span_text}\"")

    return "\n".join(details_lines)


def load_example_texts():
    """Load example texts for the demo."""
    examples = [
        "Ich stimme allen zu die denken das Roman und Heiko super sind !!!!",
        "da geb ich dir recht ich stehe dir bei die sind einfach nur geil !",
        "OMG, ihr seid einfach der absolute Hammer! 🤩 Eure Videos bringen mich jedes Mal zum Lachen und geben mir so viel Motivation – eure Stimmen klingen mega, eure Parodien sind lustiger als das Original und ihr seht dabei unfassbar toll aus! 😂👌 Bitte macht weiter so! ❤️🎉",
        "Das ist ein wirklich toller Beitrag! Vielen Dank für diese hilfreichen Informationen.",
        "Du bist so klug und hilfreich. Ich bin dir sehr dankbar für deine Unterstützung.",
        "Großartige Arbeit! Das motiviert mich wirklich weiterzumachen.",
        "Das tut mir leid zu hören. Ich hoffe, es wird bald besser für dich.",
    ]
    return examples


# Build the Gradio interface
def create_gradio_interface():
    """Build the Gradio user interface."""

    with gr.Blocks(title="Span Classifier Demo", theme=gr.themes.Soft()) as demo:
        gr.HTML("""
        <div style="text-align: center; margin: 20px 0;">
            <h1>🍭 Candy Speech Span Classifier</h1>
            <p>Analysieren Sie Texte und identifizieren Sie verschiedene Arten positiver Kommunikation.</p>
        </div>
        """)

        # Legend
        gr.HTML(create_legend())

        with gr.Row():
            with gr.Column(scale=2):
                # Input
                text_input = gr.Textbox(
                    label="Text eingeben",
                    placeholder="Geben Sie hier den Text ein, den Sie analysieren möchten...",
                    lines=5
                )

                # Buttons
                with gr.Row():
                    analyze_btn = gr.Button("Analysieren", variant="primary")
                    clear_btn = gr.Button("Löschen", variant="secondary")

                # Examples
                gr.Examples(
                    examples=load_example_texts(),
                    inputs=text_input,
                    label="Beispieltexte"
                )

                gr.Examples(
                    examples=[
                        "Bin wegen dir vegan geworden DANKE🫶 Du bist einzigartig und mach bitte weiter 🤍 🧚♀️",
                        "Danke für deine tolle Arbeit, auch schön, dass du den Permazidbegriff so wunderbar verwendest <3 Das hast du wirklich alles exzellent gemacht!",
                        "Rafaella Raab ist eine Ikone! Wir sollten alle mehr Tierrechtsaktivismus machen. Höchster Respekt!",
                    ],
                    inputs=text_input,
                    label="Out-of-Distribution Examples (Rafaella Raab)",
                )

                gr.Examples(
                    examples=[
                        "Tolles Video! Hab es einfach stumm geschaltet und tatsächlich eine gute Zeit gehabt.",  # adversarial
                        "Auf lautlos ballert der Track noch geiler. 🙏🏻",
                    ],
                    inputs=text_input,
                    label="Adversarial Example (Sarcasm)"
                )

            with gr.Column(scale=2):
                # Outputs
                highlighted_output = gr.HTML(
                    label="Analysierter Text",
                    show_label=True
                )

                summary_output = gr.Markdown(
                    label="Zusammenfassung",
                    show_label=True
                )

                details_output = gr.Markdown(
                    label="Details",
                    show_label=True
                )

        # Info section
        with gr.Accordion("ℹ️ Informationen zum Modell", open=False):
            gr.Markdown("""
            ### Über dieses Modell

            Dieses Modell identifiziert verschiedene Arten positiver Kommunikation in Texten:

            - **Positive Feedback**: Allgemein positive Rückmeldungen
            - **Compliment**: Direkte Komplimente
            - **Affection Declaration**: Liebesbekundungen oder Zuneigung
            - **Encouragement**: Ermutigung und Motivation
            - **Gratitude**: Dankbarkeit und Wertschätzung
            - **Agreement**: Zustimmung und Einverständnis
            - **Ambiguous**: Mehrdeutige positive Aussagen
            - **Implicit**: Implizite positive Kommunikation
            - **Group Membership**: Zugehörigkeitsgefühl
            - **Sympathy**: Mitgefühl und Empathie

            ### Verwendung
            1. Geben Sie einen Text in das Eingabefeld ein
            2. Klicken Sie auf "Analysieren"
            3. Betrachten Sie die hervorgehobenen Spans im analysierten Text
            4. Überprüfen Sie die Zusammenfassung und Details
            """)

        # Event handlers
        analyze_btn.click(
            fn=analyze_text,
            inputs=text_input,
            outputs=[highlighted_output, summary_output, details_output]
        )

        clear_btn.click(
            fn=lambda: ("", "", "", ""),
            outputs=[text_input, highlighted_output, summary_output, details_output]
        )

        # Auto-analyze when an example is selected
        text_input.change(
            fn=analyze_text,
            inputs=text_input,
            outputs=[highlighted_output, summary_output, details_output]
        )

    return demo



if __name__ == "__main__":
    classifier = SpanClassifierWithStrictF1('xlm-roberta-large')

    classifier.model = AutoModelForTokenClassification.from_pretrained(
        'xlm-roberta-large',
        num_labels=len(classifier.labels),
        id2label=classifier.id2label,
        label2id=classifier.label2id
    )
    classifier.model.load_state_dict(torch.load('./experiments/exp027/exp027-2_retraining_final_model.pth'))
    classifier.model.eval()

    print("Modell geladen! Starte Gradio-Interface...")

    # Build and launch the demo
    demo = create_gradio_interface()

    # Launch the demo
    demo.launch(
        server_name="0.0.0.0",  # for external access
        server_port=7860,
        debug=True,
        show_error=True
    )
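The strict F1 used throughout this file counts a predicted span as correct only when its type, start, and end all exactly equal a gold span, and each gold span can be matched at most once. A toy illustration of that counting logic, with hypothetical spans:

```python
# Toy illustration of strict span matching: (type, start, end) must be equal.
true_spans = [("compliment", 0, 12), ("gratitude", 20, 35)]
pred_spans = [("compliment", 0, 12), ("gratitude", 20, 34)]  # end off by one

matches = [s for s in pred_spans if s in true_spans]
tp = len(matches)           # 1: the exact compliment match
fp = len(pred_spans) - tp   # 1: the near-miss gratitude span counts fully as FP
fn = len(true_spans) - tp   # 1: the unmatched gold gratitude span

precision = tp / (tp + fp)  # 0.5
recall = tp / (tp + fn)     # 0.5
f1 = 2 * precision * recall / (precision + recall)
print(f1)                   # 0.5
```

Boundary errors of even one character therefore cost both a false positive and a false negative, which is why the strict scores in the README sit well below typical token-level accuracies.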
subtask_1/exp019-4.py
ADDED
|
@@ -0,0 +1,224 @@
#!/usr/bin/env python
# coding: utf-8

### Experiment 019-4
# - Model: Qwen/Qwen3-Embedding-8B

import os
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, make_scorer, classification_report
from sklearn.model_selection import StratifiedKFold, train_test_split, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
import time
import pickle
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from transformers import AutoModel, AutoTokenizer
from transformers.utils import is_flash_attn_2_available
import wandb
from wandb import AlertLevel


os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
os.environ["WANDB_PROJECT"] = "GermEval2025-Substask1"
os.environ["WANDB_LOG_MODEL"] = "false"

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    print("CUDA not available, using CPU")

experiment_name = "exp019-4"

testing_mode = False


# Load data
comments = pd.read_csv("../../share-GermEval2025-data/Data/training data/comments.csv")
task1 = pd.read_csv("../../share-GermEval2025-data/Data/training data/task1.csv")
comments = comments.merge(task1, on=["document", "comment_id"])

# Remove duplicates
df = comments.drop_duplicates(subset=['comment', 'flausch'])
df.reset_index(drop=True, inplace=True)

# Use only a small subset for testing
if testing_mode:
    os.environ["WANDB_MODE"] = "offline"
    testing_mode_sample_size = 1000
    df = df.sample(n=testing_mode_sample_size, random_state=42).reset_index(drop=True)
    print(f"Testing mode: using only {testing_mode_sample_size} samples for quick testing.")

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Pool each sequence to the hidden state of its last non-padding token;
    # with a fully left-padded batch that is simply the final position.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
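
# Illustration of the pooling logic on a toy batch (hidden size 1):
#   h = torch.arange(6, dtype=torch.float32).reshape(2, 3, 1)
#   mask = torch.tensor([[1, 1, 0], [1, 1, 1]])   # right-padded batch
#   last_token_pool(h, mask)                      # -> tensor([[1.], [5.]])
# i.e. the hidden state of each sequence's last real token is selected.
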
class Qwen3Embedder:
    def __init__(self, model_name='Qwen/Qwen3-Embedding-8B', instruction=None, max_length=1024):
        if instruction is None:
            instruction = 'Classify a given comment as either flausch (a positive, supportive expression) or non-flausch.'
        self.instruction = instruction

        if is_flash_attn_2_available():
            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16)
        else:
            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16)

        self.model = self.model.cuda()
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left')
        self.max_length = max_length

    def get_detailed_instruct(self, query: str) -> str:
        return f'Instruct: {self.instruction}\nQuery:{query}'

    def encode_batch(self, texts, batch_size=32):
        """Encode texts in batches to handle memory efficiently."""
        all_embeddings = []

        for i in range(0, len(texts), batch_size):
            batch_texts = [self.get_detailed_instruct(comment) for comment in texts[i:i + batch_size]]

            # Tokenize batch
            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            ).to(device)

            # Get embeddings
            with torch.no_grad():
                outputs = self.model(**inputs)
                # Last-token pooling (not mean pooling; see last_token_pool above)
                embeddings = last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
                # embeddings = embeddings.float()

            all_embeddings.append(embeddings.cpu().numpy())

        # Normalize embeddings? (left as an open question in the original)
        # import torch.nn.functional as F
        # output = F.normalize(all_embeddings, p=2, dim=1)
        return np.vstack(all_embeddings)

# Initialize embedder
print("Loading Qwen3-Embedding-8B...")
embedder = Qwen3Embedder(instruction='Classify a given comment as either flausch (a positive, supportive expression) or non-flausch')

X, y = df["comment"], df["flausch"].map(dict(yes=1, no=0))

# Load cached embeddings if they exist
embeddings_file = f'{"testing_" if testing_mode else ""}Qwen3-Embedding-8B-{experiment_name}.npy'
if os.path.exists(embeddings_file):
    print(f"Loading existing embeddings from {embeddings_file}")
    X_embeddings = np.load(embeddings_file)
else:
    print("Embeddings not found, generating new embeddings...")
    # Encode texts in batches to avoid memory issues
    X_embeddings = embedder.encode_batch(X.tolist(), batch_size=64)
    print(f"Generated embeddings with shape: {X_embeddings.shape}")

    # Save embeddings to avoid recomputation
    np.save(embeddings_file, X_embeddings)

wandb.init(
    project=os.environ["WANDB_PROJECT"],
    dir='./wandb_logs',
    name=f"{experiment_name}",
)

# 5-fold stratified cross-validation
kf_splits = 5

pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("svm", SVC(random_state=42, cache_size=2000))
])

param_grid = [
    {
        # Fitting 5 folds for each of 48 candidates, totalling 240 fits
        'svm__kernel': ['rbf'],
        'svm__C': [5, 6, 7, 8, 9, 10],
        'svm__gamma': [0.00008, 0.0001, 0.0002, 1/4096, 0.0003, 0.0004, 0.0005, 0.0006]
        # This range is chosen because Qwen3-Embedding-8B embeddings have 4096
        # dimensions, so gamma='auto' (1/n_features) would land at 1/4096 ≈ 2.4e-4
    },
    # {
    #     'kernel': ['poly'],
    #     'C': [0.1, 1, 10, 100],
    #     'degree': [2, 3, 4],
    #     'gamma': ['scale', 'auto', 0.001, 0.01],
    #     'coef0': [0.0, 0.1, 0.5, 1]
    # }
]


f1_pos_scorer = make_scorer(f1_score, pos_label=1, average='binary')

X_train = X_embeddings
y_train = y

# 5-fold stratified CV for the grid search
cv_inner = StratifiedKFold(n_splits=kf_splits, shuffle=True, random_state=42)

grid = GridSearchCV(
    estimator=pipe,
    param_grid=param_grid,
    cv=cv_inner,
    scoring=f1_pos_scorer,
    n_jobs=63,
    verbose=3,
    return_train_score=True
)

grid.fit(X_train, y_train)

# Print the results
print("Best F1 (pos) on CV:", grid.best_score_)
print("Best parameters:", grid.best_params_)
print("Best estimator:", grid.best_estimator_)


with open(f'scores.{experiment_name}.txt', 'a') as f:
    f.write(f'[{time.strftime("%Y-%m-%d %H:%M:%S")}] {kf_splits}Fold CV\n')
    f.write(f'[{experiment_name}] Best F1 (pos) on CV: {grid.best_score_}\n')
    f.write(f'[{experiment_name}] Best parameters: {grid.best_params_}\n')
    f.write(f'[{experiment_name}] Best estimator: {grid.best_estimator_}\n')

results = pd.DataFrame(grid.cv_results_).sort_values("rank_test_score")
print("grid.cv_results_:")
print(results)
results.to_csv(f'grid_cv_results.{experiment_name}.csv', index=False)

with open(f"grid_cv.{experiment_name}.pkl", "wb") as f:
    pickle.dump(grid, f)

print(f"GridSearchCV results saved to grid_cv_results.{experiment_name}.csv")

print(f"Training completed with {len(X_train)} samples...")


print("Experiment completed!")

wandb.alert(
    title=f'Experiment {experiment_name} finished!',
    text=f'Best F1 (pos): {grid.best_score_:.4f}\nBest Params: {grid.best_params_}',
    level=AlertLevel.INFO
)
wandb.finish()
print("Notification sent via Weights & Biases.")
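
Because the script pickles the whole fitted GridSearchCV object, the winning pipeline can be reused later without refitting. A minimal sketch of that reuse, assuming the files written by exp019-4.py above (fresh comments would first need to be embedded with the same Qwen3Embedder):

```python
import pickle
import numpy as np

# Load the fitted grid search saved by exp019-4.py
with open("grid_cv.exp019-4.pkl", "rb") as f:
    grid = pickle.load(f)

# best_estimator_ is the full Pipeline (StandardScaler -> SVC), so it accepts
# raw Qwen3 embeddings directly; here we reuse the cached training embeddings.
X = np.load("Qwen3-Embedding-8B-exp019-4.npy")
preds = grid.best_estimator_.predict(X)   # 1 = flausch, 0 = non-flausch
print(preds[:10])
```
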
subtask_1/grid_cv_results.exp019-2.csv
ADDED
@@ -0,0 +1,16 @@
mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_svm__C,param_svm__gamma,param_svm__kernel,params,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,mean_test_score,std_test_score,rank_test_score,split0_train_score,split1_train_score,split2_train_score,split3_train_score,split4_train_score,mean_train_score,std_train_score
2353.793872785568,376.68300732502587,668.3643176555634,245.2577496652275,10,0.0001,rbf,"{'svm__C': 10, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.905587668593449,0.8904684975767366,0.8992898644286637,0.8973607038123167,0.8943348185868873,0.8974083105996108,0.00506066727524964,1,0.9858907931446792,0.9876582530456246,0.9873357228195938,0.9868829000715478,0.985828025477707,0.9867191389118304,0.0007442000606650537
2135.3830691814424,220.07048372247345,630.6276105880737,168.76826649669474,100,0.0001,rbf,"{'svm__C': 100, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8957055214723927,0.8829236739974127,0.8949742268041238,0.8936866208701341,0.8826170622193714,0.889981421072687,0.0059239769597075765,2,0.9961636828644501,0.9964048893504833,0.9962481040951545,0.9965663179749261,0.9964031652146111,0.996357231899925,0.0001396196230179917
2107.9709944725037,201.0339803887574,590.9469698905945,148.89389411692406,100,1e-05,rbf,"{'svm__C': 100, 'svm__gamma': 1e-05, 'svm__kernel': 'rbf'}",0.8903143040410519,0.8850981654328934,0.8919093851132686,0.8930323846908734,0.8893141945773525,0.8899336867710879,0.002735551843186545,3,0.9536455818445195,0.9543955602026863,0.9526727404660162,0.9510804080649048,0.9523425530199178,0.9528273687196089,0.001134859016259991
2023.2640014648437,363.8169219637371,634.6866162300109,296.19671434247226,1,0.0001,rbf,"{'svm__C': 1, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8889618922470434,0.8828711256117455,0.8905785123966942,0.8863031914893617,0.8907672301690507,0.887896390382779,0.002978669586634141,4,0.9137704918032787,0.9158359209021082,0.9118032786885246,0.9141129361771676,0.9165093177900008,0.914406389072216,0.0016572588242453185
2190.105693435669,279.15396660484754,409.17747988700864,88.07139447254528,10,1e-05,rbf,"{'svm__C': 10, 'svm__gamma': 1e-05, 'svm__kernel': 'rbf'}",0.8841483426320972,0.8777777777777778,0.8815572418343781,0.8781624500665779,0.8868660598179454,0.8817023744257553,0.0034814075912781846,5,0.8997370151216305,0.9031255113729341,0.9008447469859756,0.9009422367882015,0.9013087496913326,0.9011916519920149,0.0011001857505116694
1812.7185018062592,487.2391337311374,417.01636271476747,166.52173260741165,100,1e-06,rbf,"{'svm__C': 100, 'svm__gamma': 1e-06, 'svm__kernel': 'rbf'}",0.8814449917898194,0.8751633986928105,0.8784676354029062,0.8762920973657886,0.8861418347430059,0.8795019915988661,0.003951188905961493,6,0.896869093598488,0.8999426370564615,0.8977850697292863,0.8984509466437177,0.8973240016467682,0.8980743497349442,0.0010706699096367878
1993.3692329406738,456.7097004392084,522.3544264793396,304.12100473275353,100,1e-07,rbf,"{'svm__C': 100, 'svm__gamma': 1e-07, 'svm__kernel': 'rbf'}",0.8652246256239601,0.8579842931937173,0.8614257161892072,0.8590559089387345,0.8706190632165084,0.8628619214324255,0.004606414952929956,7,0.866143034311699,0.8684057971014493,0.8666226477385276,0.8668761369274021,0.8654660137770769,0.866702725971231,0.0009777164172269123
2247.661126232147,581.4474017109956,708.4357552051545,236.8812027154758,10,1e-06,rbf,"{'svm__C': 10, 'svm__gamma': 1e-06, 'svm__kernel': 'rbf'}",0.8649367930805056,0.8577036310107949,0.8610463178940353,0.8590559089387345,0.8702490170380078,0.8625983335924156,0.004536892956190344,8,0.8661925239827986,0.8684057971014493,0.8667161838738962,0.8669698222405953,0.8654660137770769,0.8667500681951633,0.0009747337581279379
2407.543846988678,345.9172929856225,600.8789489269257,135.20490614298708,1,1e-05,rbf,"{'svm__C': 1, 'svm__gamma': 1e-05, 'svm__kernel': 'rbf'}",0.8651348651348651,0.8569558101472995,0.8611388611388612,0.8588669125041904,0.8698787282858079,0.8623950354422047,0.004628674161723668,9,0.8663853727144867,0.8689746562862349,0.8666721703954429,0.8674997930977406,0.8655427290092185,0.8670149443006248,0.0011624819670864098
3116.9194386959075,482.1913324221538,808.8905973434448,337.4328413836179,10,1e-07,rbf,"{'svm__C': 10, 'svm__gamma': 1e-07, 'svm__kernel': 'rbf'}",0.839123102866779,0.8371010638297872,0.835820895522388,0.8388851121685927,0.8458445040214477,0.8393549356817989,0.0034628992476326325,10,0.8390572390572391,0.8410299704516674,0.8408499202149996,0.839344262295082,0.8390630266262218,0.839868883729042,0.0008824914001438402
3009.1526700019836,742.3166871679517,681.0144076347351,150.1665691333959,100,1e-08,rbf,"{'svm__C': 100, 'svm__gamma': 1e-08, 'svm__kernel': 'rbf'}",0.839123102866779,0.8371010638297872,0.835820895522388,0.8388851121685927,0.8458445040214477,0.8393549356817989,0.0034628992476326325,10,0.8390572390572391,0.8409321175278622,0.8409472623446422,0.839344262295082,0.8390630266262218,0.8398687815702095,0.0008805415253987308
2988.9541951179503,539.1745917351178,573.9239889621734,154.38397714010605,1,1e-06,rbf,"{'svm__C': 1, 'svm__gamma': 1e-06, 'svm__kernel': 'rbf'}",0.839123102866779,0.8367143332224809,0.836104513064133,0.8391703502210133,0.8454575930271538,0.839313978480312,0.003312397500333861,12,0.8390572390572391,0.8408783783783784,0.8407696832198975,0.8393172454384933,0.8388401888064734,0.8397725469800964,0.0008723990422233897
4479.428878545761,394.52193518745116,530.3013439178467,238.64268098954136,10,1e-08,rbf,"{'svm__C': 10, 'svm__gamma': 1e-08, 'svm__kernel': 'rbf'}",0.7659099367324154,0.7609828741623231,0.7684839432412248,0.7571860816944024,0.7789240972733972,0.7662973866207526,0.007424619711579136,13,0.7651345291479821,0.7695741119583411,0.7656876456876457,0.7679003161614283,0.7625374251497006,0.7661668056210196,0.002411737416643407
4651.474856758117,292.18073753504444,416.76064705848694,180.45304034635336,1,1e-07,rbf,"{'svm__C': 1, 'svm__gamma': 1e-07, 'svm__kernel': 'rbf'}",0.7659099367324154,0.7609828741623231,0.7675635276532138,0.7571860816944024,0.7789240972733972,0.7661133035031504,0.007379397530916294,14,0.7651345291479821,0.7694167984373547,0.7653422868867749,0.7678571428571429,0.7626531948732341,0.7660807904404977,0.002344081336836005
4201.746336603164,491.5078042590254,626.0860621452332,239.76149256091043,1,1e-08,rbf,"{'svm__C': 1, 'svm__gamma': 1e-08, 'svm__kernel': 'rbf'}",0.0,0.0,0.0,0.0,0.0,0.0,0.0,15,0.0,0.0,0.0,0.0,0.0,0.0,0.0
|
subtask_1/grid_cv_results.exp019-3.csv
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_svm__C,param_svm__gamma,param_svm__kernel,params,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,mean_test_score,std_test_score,rank_test_score,split0_train_score,split1_train_score,split2_train_score,split3_train_score,split4_train_score,mean_train_score,std_train_score
5455.408059644699,398.6684175930872,1819.24121799469,702.8565515706526,5,0.0001,rbf,"{'svm__C': 5, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8922446890197082,0.9103590803409971,0.9039256198347108,0.8975026014568158,0.8993743482794577,0.900681267786338,0.0061183935076840005,1,0.9708551706506455,0.9694563616571684,0.9677872938770292,0.9689401216778738,0.9681381957773513,0.9690354287280136,0.0010826445588126672
6815.947726678848,669.4113202679961,1307.7390199661254,718.2998951073738,5,0.000244140625,rbf,"{'svm__C': 5, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8961272121056681,0.9077242094349404,0.9043793728945323,0.8957629321549259,0.8962091503267974,0.9000405753833727,0.005023063960942,2,0.9937460114869177,0.9923479148067849,0.9926672192820252,0.9922114402451481,0.9921511071405782,0.9926247385922908,0.0005883589375966587
5478.7035518169405,651.6782882254886,1624.6494281768798,581.1884804375871,10,0.0001,rbf,"{'svm__C': 10, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.892802450229709,0.9092311648238621,0.9037915914366779,0.8953367875647669,0.8985959438377535,0.8999515875785538,0.005917778255965958,3,0.9869931140015302,0.9857397504456328,0.9857560727457714,0.9850822389391815,0.9851469369541659,0.9857436226172563,0.0006862741243773296
5552.503013324737,436.61624342735956,1298.7537933349608,304.41444617181907,10,5e-05,rbf,"{'svm__C': 10, 'svm__gamma': 5e-05, 'svm__kernel': 'rbf'}",0.8918435182817693,0.9078233927188226,0.9030428055698814,0.8969286829776159,0.896551724137931,0.899238024737204,0.005575175271795643,4,0.9582905544147844,0.9566334725345326,0.9559974342527261,0.9574892276030613,0.9567956795679567,0.9570412736746123,0.0007845487956189628
6738.8976334095005,581.7043147683858,2363.1046784877776,1093.7550833856315,10,0.000244140625,rbf,"{'svm__C': 10, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8940754039497307,0.9058854031630801,0.9022556390977443,0.8963367108339828,0.895631702851164,0.8988369719791404,0.004484757587607084,5,0.9951474907419231,0.9944476354585488,0.9947664028593312,0.9948272558911808,0.99457111834962,0.9947519806601207,0.00023984985392036138
5479.927654600144,524.2887796010415,1723.1177164077758,872.0661863515872,1,0.000244140625,rbf,"{'svm__C': 1, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.892395240558717,0.9070010449320794,0.9032594524119948,0.8916929547844374,0.8938753959873285,0.8976448177349114,0.006265052008456617,6,0.9525471942073959,0.9506923773780251,0.9504745915929489,0.9507538989193037,0.9485536788973015,0.9506043481989949,0.0012670316908732473
7194.264643144607,434.09262123997644,2276.0355013370513,1194.8076020168799,50,0.000244140625,rbf,"{'svm__C': 50, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8937371663244353,0.9031420410283043,0.8979698073919833,0.8946135831381733,0.89501312335958,0.8968951442484953,0.003432170407915575,7,0.9961039790509038,0.9951456310679612,0.9960398569238631,0.9957185762668541,0.995463548655038,0.9956943183929241,0.0003583628407093519
4697.035187005997,526.5802792297949,989.723811674118,148.799116364752,5,5e-05,rbf,"{'svm__C': 5, 'svm__gamma': 5e-05, 'svm__kernel': 'rbf'}",0.890495867768595,0.9052686218531015,0.8997668997668997,0.892203035060178,0.8958825072121689,0.8967233863321885,0.005333041625787206,8,0.9384954033406707,0.9362170562714498,0.9359148112294289,0.9368972882014109,0.935051479634786,0.9365152077355493,0.0011540390227375986
7055.811586093902,962.5495699591098,2110.023814868927,705.5011720606061,100,0.000244140625,rbf,"{'svm__C': 100, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8933436134669751,0.9034267912772586,0.8952033368091762,0.89421573736321,0.8953030700603516,0.8962985097953944,0.003635163114263367,9,0.9962953500255494,0.9954005366040629,0.9964230965763924,0.9957827476038339,0.9955913360168679,0.9958986133653414,0.0003970912539776255
8686.581301164628,767.646570958569,2040.4449747562408,609.7467567837456,5,0.0005,rbf,"{'svm__C': 5, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8894668400520156,0.9034120734908136,0.9006831318970048,0.893740136770121,0.8925531914893617,0.8959710747398633,0.005225629609478866,10,0.9954676029364826,0.9948285769009768,0.9954664453100057,0.9950820719167146,0.9949543335249409,0.9951598061178242,0.0002633404117486148
9230.11362876892,405.54152137987444,1368.0186800956726,507.5534238465096,10,0.0005,rbf,"{'svm__C': 10, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8886576482830385,0.9028871391076115,0.8997104501184522,0.8933929981574098,0.8917265230114392,0.8952749517355901,0.005244823031429965,11,0.995977782034093,0.9951462511176395,0.9960398569238631,0.9955907725733274,0.9954641282821185,0.9956437581862083,0.0003319800601900671
8433.545366239548,764.8027335950326,2101.967122411728,348.04558910621176,1,0.0005,rbf,"{'svm__C': 1, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8870925684485007,0.9047993705743509,0.9014752370916754,0.8885941644562334,0.8912579957356077,0.8946438672612735,0.007139846007850025,12,0.9774985575998462,0.9765354532632389,0.9757940573770492,0.9755628247065615,0.9760348583877996,0.976285150266899,0.0006871071840619478
5641.2619892120365,301.3642532545948,1245.2856984615325,258.0022220586702,100,1e-05,rbf,"{'svm__C': 100, 'svm__gamma': 1e-05, 'svm__kernel': 'rbf'}",0.8904214559386974,0.9010819165378671,0.8979907264296755,0.8900414937759336,0.8935837245696401,0.8946238634503627,0.004307724355640959,13,0.9508512688724703,0.9488647327458674,0.9488413890493613,0.949270237253263,0.9493206259256874,0.94942965076933,0.0007380354112406489
8291.015897274017,703.4766614629165,1496.3943771839142,462.8280066865296,50,0.0005,rbf,"{'svm__C': 50, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8880228630813198,0.90154896298241,0.8979161171194935,0.893516078017923,0.8913738019169329,0.8944755646236159,0.0047759662781075955,14,0.9963594558344511,0.9954005366040629,0.9964226395809378,0.9958463799603808,0.9956555072834142,0.9959369038526494,0.00039738087640584543
6325.810354614257,330.9400455647892,1642.774357700348,688.5875334404985,50,0.0001,rbf,"{'svm__C': 50, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.889171974522293,0.9000258331180574,0.8982655966865131,0.8940874035989718,0.8903966597077244,0.8943894935267119,0.0042437951054770245,15,0.9954028859660324,0.9945745835194996,0.9951493489915751,0.9949543335249409,0.9946995338144198,0.9949561371632936,0.0002994484991425243
5036.9365752696995,632.7213571409109,1517.7294404506683,426.3296032210096,50,1e-05,rbf,"{'svm__C': 50, 'svm__gamma': 1e-05, 'svm__kernel': 'rbf'}",0.8891170431211499,0.9033428349313294,0.8969258589511754,0.8912591050988553,0.8910994764397906,0.89434886370846,0.005196335158243723,16,0.9354442649434572,0.9340737392651901,0.9350515463917526,0.9344272885845338,0.9342130797593635,0.9346419837888595,0.0005223698139574696
8931.353061866761,1234.4122853976255,2149.3760446071624,1033.6010444948738,100,0.0005,rbf,"{'svm__C': 100, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8868071818891491,0.9014972419227738,0.8973885518332894,0.8928759894459103,0.8914316125598722,0.8940001155301989,0.00504554389640692,17,0.9964862965565706,0.9954647077610987,0.9964226395809378,0.9958463799603808,0.9956555072834142,0.9959751062284804,0.0004100785745162828
5712.9752474784855,422.97193138513984,1671.5140540599823,728.145449370879,50,5e-05,rbf,"{'svm__C': 50, 'svm__gamma': 5e-05, 'svm__kernel': 'rbf'}",0.8886069525501142,0.9009240246406571,0.8988937483920761,0.8875160875160876,0.8914285714285715,0.8934738769055013,0.005445093016473903,18,0.9919041244342449,0.9908888180949347,0.9909565660425423,0.9908151549942594,0.990498054970984,0.991012543707393,0.00047265604462407793
4886.372617149353,497.6637188714639,1340.4910155296325,512.4187982291547,1,0.0001,rbf,"{'svm__C': 1, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8866943866943867,0.9011277209546289,0.8970934799685781,0.8871903004744334,0.8881248346998148,0.8920461445583683,0.005925332557641618,19,0.9178055319427189,0.9155316919853327,0.9151811949069539,0.9175931981687377,0.9152409559879848,0.9162705145983455,0.0011745839476710812
5400.760613489151,277.7619437511136,1815.8284684658051,495.1351854762197,100,0.0001,rbf,"{'svm__C': 100, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8901910828025478,0.8979064357715172,0.890272373540856,0.8911196911196911,0.8878431372549019,0.8914665440979028,0.0033992479680527554,20,0.9960403627538639,0.9950820719167146,0.9959767545820295,0.9955282994761723,0.9954647077610987,0.9956184392979758,0.0003537959259237909
5831.041041278839,500.2366199022952,1163.3661043167115,651.9972938828764,100,5e-05,rbf,"{'svm__C': 100, 'svm__gamma': 5e-05, 'svm__kernel': 'rbf'}",0.8865194211728865,0.8980957282552754,0.8943298969072165,0.885523613963039,0.88518614944025,0.8899309619477336,0.00528376021564831,21,0.9947650663942799,0.9939378469784953,0.9945118059987237,0.9943170934167678,0.9939982122334312,0.9943060050043396,0.00031096336881715784
5106.945345258713,501.77199871724264,1510.7054524421692,525.6496772631232,10,1e-05,rbf,"{'svm__C': 10, 'svm__gamma': 1e-05, 'svm__kernel': 'rbf'}",0.8815584415584415,0.895577074064381,0.8924928066963118,0.8827404479578392,0.8807193864057128,0.8866176313365373,0.006167763694053975,22,0.9048678356451191,0.9032511798636602,0.9037820889672742,0.9035087719298246,0.9037405179178656,0.9038300788647486,0.0005523813046370195
5257.808340215683,524.2623404859893,1091.6346819400787,557.959192374461,1,5e-05,rbf,"{'svm__C': 1, 'svm__gamma': 5e-05, 'svm__kernel': 'rbf'}",0.8762402088772846,0.89304531085353,0.8867825171142707,0.8729693741677763,0.8740978348035284,0.880627049163278,0.007907186646214762,23,0.8958237118163225,0.8928783578641674,0.893843725335438,0.8947437755236465,0.8950138139718458,0.894460676902284,0.0010127519339155767
4648.525324726104,577.1027936552093,1347.1555349826813,729.69553292323,5,1e-05,rbf,"{'svm__C': 5, 'svm__gamma': 1e-05, 'svm__kernel': 'rbf'}",0.8768909754825248,0.8896497234658941,0.8837453971593898,0.8710191082802548,0.8731642189586115,0.878893884669335,0.006900406643147761,24,0.8914851485148515,0.8886542142432843,0.8901597318083219,0.8914504017915953,0.8917691953417988,0.8907037383399704,0.001166173136675557
5639.493494796753,385.0617052361121,1839.803660440445,683.8481740973009,1,1e-05,rbf,"{'svm__C': 1, 'svm__gamma': 1e-05, 'svm__kernel': 'rbf'}",0.8623949579831933,0.87595874107379,0.8709677419354839,0.8563815614175326,0.8553763440860215,0.8642158692992042,0.008078516269848037,25,0.8684158087014281,0.863926576217079,0.8676792465344565,0.8695018547959724,0.8706286771997092,0.8680304326897289,0.0022816415968571016
subtask_1/grid_cv_results.exp019-4.csv
ADDED
@@ -0,0 +1,49 @@
mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_svm__C,param_svm__gamma,param_svm__kernel,params,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,mean_test_score,std_test_score,rank_test_score,split0_train_score,split1_train_score,split2_train_score,split3_train_score,split4_train_score,mean_train_score,std_train_score
2962.254550600052,556.4586160586341,952.6687445163727,88.22149099477036,5,0.0002,rbf,"{'svm__C': 5, 'svm__gamma': 0.0002, 'svm__kernel': 'rbf'}",0.8956142600666838,0.9105480868665977,0.9049095607235143,0.8974292391586601,0.8989292243405589,0.901486074231203,0.005501162177555585,1,0.9913298482723447,0.9899337410805301,0.9902541563156889,0.9903680551125853,0.9896012759170654,0.9902974153396429,0.0005813171498605362
2907.5319063186644,622.9412053502689,893.9788520812988,156.92165447691596,6,0.0001,rbf,"{'svm__C': 6, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8931492842535788,0.9111053852099974,0.903359173126615,0.89893478825669,0.9000520562207184,0.90132013741352,0.005898635485442219,2,0.9759805800434393,0.9748242811501597,0.9740516416958878,0.973975318114969,0.9735463258785942,0.97447562937661,0.0008579179998278246
2294.846938943863,497.95159046395213,750.493603515625,166.74923962097026,5,0.0001,rbf,"{'svm__C': 5, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8922446890197082,0.9103590803409971,0.9039256198347108,0.8975026014568158,0.8993743482794577,0.900681267786338,0.0061183935076840005,3,0.9708551706506455,0.9694563616571684,0.9677872938770292,0.9689401216778738,0.9681381957773513,0.9690354287280136,0.0010826445588126672
3643.660668420792,527.2542232689798,969.6882706165313,72.65267607425072,6,0.0002,rbf,"{'svm__C': 6, 'svm__gamma': 0.0002, 'svm__kernel': 'rbf'}",0.8967989756722151,0.9092788834324115,0.9033092037228542,0.8959792477302205,0.8967859942513718,0.9004304609618146,0.005154897910679173,4,0.9929829038019903,0.9916491362274494,0.9919683834778175,0.9918927545483562,0.991261083115392,0.991950852234201,0.0005719487931925857
2822.824296474457,146.78322313502832,943.8164008617401,103.05704858326389,7,0.0001,rbf,"{'svm__C': 7, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8936061381074168,0.9099330931549151,0.9022440030951767,0.8980544747081712,0.8982035928143712,0.9004082603760102,0.005491122365808926,5,0.9795709908069459,0.9783732057416268,0.9782331975560081,0.9780332056194125,0.9779054916985952,0.9784232182845176,0.0005960103340822639
3752.820015335083,675.8436439014653,1054.1311081886292,142.1847999664584,5,0.000244140625,rbf,"{'svm__C': 5, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8961272121056681,0.9077242094349404,0.9043793728945323,0.8957629321549259,0.8962091503267974,0.9000405753833727,0.005023063960942,6,0.9937460114869177,0.9923479148067849,0.9926672192820252,0.9922114402451481,0.9921511071405782,0.9926247385922908,0.0005883589375966587
2866.750690841675,62.310380211696746,1010.1708914756775,72.51314774382548,8,0.0001,rbf,"{'svm__C': 8, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8921819110884006,0.9099794238683128,0.9029927760577915,0.8963730569948186,0.8984903695991671,0.9000035075216981,0.006084435249086387,7,0.9825948358304112,0.9816466989548814,0.9812881873727087,0.9812428225086129,0.9811055789608069,0.9815756247254843,0.0005400302566906774
2726.576060628891,1076.4445437635895,767.500729751587,324.01429729941805,10,0.0001,rbf,"{'svm__C': 10, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.892802450229709,0.9092311648238621,0.9037915914366779,0.8953367875647669,0.8985959438377535,0.8999515875785538,0.005917778255965958,8,0.9869931140015302,0.9857397504456328,0.9857560727457714,0.9850822389391815,0.9851469369541659,0.9857436226172563,0.0006862741243773296
2161.447349357605,573.8507236892065,765.411279296875,242.26387746003647,6,8e-05,rbf,"{'svm__C': 6, 'svm__gamma': 8e-05, 'svm__kernel': 'rbf'}",0.8908207619534646,0.9096074380165289,0.9025270758122743,0.8983315954118873,0.8982785602503912,0.8999130862889093,0.0061425816597368895,9,0.9661136378194862,0.9648627853295717,0.9627400768245838,0.9644253573488879,0.9637039887136079,0.9643691692072274,0.0011301674037890955
2511.2280893802645,195.0702765837904,851.77890791893,159.33818171498493,8,8e-05,rbf,"{'svm__C': 8, 'svm__gamma': 8e-05, 'svm__kernel': 'rbf'}",0.8913876820853565,0.9091377091377091,0.9012122775341759,0.8978955572876072,0.8995837669094693,0.8998433985908635,0.005720466311006615,10,0.973877498882289,0.9732217038409919,0.9724454649827784,0.97250990921877,0.9717940518068436,0.9727697257463346,0.000714987388108625
3198.1750497817993,566.8190693243176,1058.364576435089,61.431693130672315,6,0.000244140625,rbf,"{'svm__C': 6, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8962506420133539,0.9072058061171592,0.9038611039129308,0.8949557982319293,0.8966780015694481,0.8997902703689643,0.004840437232223766,11,0.9940012763241863,0.9933014354066986,0.9936842105263158,0.9934239928493903,0.9932311621966794,0.9935284154606542,0.00028236608448091314
3066.5988575458528,147.7978752019348,831.5851910114288,63.516287946945454,9,0.0001,rbf,"{'svm__C': 9, 'svm__gamma': 0.0001, 'svm__kernel': 'rbf'}",0.8921819110884006,0.9088568486096807,0.903975219411461,0.8948186528497409,0.8988822459059007,0.8997429755730367,0.006052009180756879,12,0.9850803366488141,0.9839490445859873,0.9834626637832337,0.9832259710440717,0.9836086485107468,0.9838653329145707,0.0006511612138484461
2721.2118691921232,273.9275976646404,1010.9925980567932,33.626638488332404,7,8e-05,rbf,"{'svm__C': 7, 'svm__gamma': 8e-05, 'svm__kernel': 'rbf'}",0.891332140117617,0.9085287297088379,0.9025270758122743,0.8983623602807382,0.897342365815529,0.8996185343469992,0.005714445566010723,13,0.9709135076391996,0.968944099378882,0.9676553311173612,0.968633977723723,0.9677336747759283,0.9687761181270188,0.0011798056774958321
3768.9346249580385,784.9161006658384,1195.0321260929109,90.27548545731533,5,0.0003,rbf,"{'svm__C': 5, 'svm__gamma': 0.0003, 'svm__kernel': 'rbf'}",0.8932338564445588,0.9079563182527302,0.9044155844155845,0.8965876530346445,0.8954128440366973,0.8995212512368429,0.0056589962755815805,14,0.9941923543302061,0.9936200076559908,0.9939378469784953,0.9938705146213765,0.9935492112154308,0.9938339869602999,0.0002312715774820335
2439.5963623046873,788.1385495478739,745.0145160675049,163.93709476186407,5,8e-05,rbf,"{'svm__C': 5, 'svm__gamma': 8e-05, 'svm__kernel': 'rbf'}",0.8918503331624807,0.9092319627618308,0.9031758326878389,0.895903991651448,0.897288842544317,0.899490192561583,0.006075528952760006,15,0.9595427690726945,0.9579475308641975,0.9568119104151961,0.9581778406897439,0.9577120822622108,0.9580384266608087,0.0008833557807307179
2662.4520683288574,110.2790439021587,789.4306183815003,35.48304290980892,9,8e-05,rbf,"{'svm__C': 9, 'svm__gamma': 8e-05, 'svm__kernel': 'rbf'}",0.8912155260469867,0.9084362139917695,0.9014447884416925,0.8966770508826584,0.8995314940135346,0.8994610146753285,0.00566097280654277,16,0.9770709586766303,0.9758121130895399,0.975796178343949,0.9754569858110699,0.9752826211917992,0.9758837714225976,0.0006270165226622817
3648.960864543915,158.20137698659786,1278.5936987876892,172.71756799035126,8,0.0003,rbf,"{'svm__C': 8, 'svm__gamma': 0.0003, 'svm__kernel': 'rbf'}",0.8933436134669751,0.9070554543087738,0.9040312093628089,0.8966414996094767,0.8953030700603516,0.8992749693616773,0.005311306707498723,17,0.9951474907419231,0.9944476354585488,0.9949569103096074,0.9948272558911808,0.99457111834962,0.9947900821501762,0.00025383825787875344
3668.4558837890627,330.5950233154227,1209.3678759098052,193.34476642369032,6,0.0003,rbf,"{'svm__C': 6, 'svm__gamma': 0.0003, 'svm__kernel': 'rbf'}",0.8930041152263375,0.9070554543087738,0.9038961038961039,0.8970358814352574,0.8953030700603516,0.8992589249853647,0.005329214832900393,18,0.9946380697050938,0.9940654712526322,0.9943185445260134,0.9943805874840358,0.9941229078829692,0.9943051161701488,0.00020364863996692793
3648.1113197803497,250.88385803011812,1053.2344131469727,110.24198198721432,7,0.0003,rbf,"{'svm__C': 7, 'svm__gamma': 0.0003, 'svm__kernel': 'rbf'}",0.8932887631781949,0.9067708333333333,0.9041309431021044,0.8969286829776159,0.8950682056663168,0.8992374856515131,0.00526873832264323,19,0.9950833280122597,0.9943207198009061,0.9947029165868914,0.9947002107145138,0.9944433799578464,0.9946501110144835,0.000262378229639201
3336.6828705310822,182.4338639019245,984.4061690330506,65.03067948107055,7,0.0002,rbf,"{'svm__C': 7, 'svm__gamma': 0.0002, 'svm__kernel': 'rbf'}",0.8957212400717397,0.9084798345398138,0.9010362694300518,0.8948871009602907,0.8958660387231816,0.8991980967450155,0.0051245558432841155,20,0.993746809596733,0.992601097078709,0.9931113662456946,0.9925311203319502,0.9926597306440289,0.992930024779423,0.00045615770854314354
2673.180758523941,1032.395217646984,545.5460843086242,130.47572525732375,10,8e-05,rbf,"{'svm__C': 10, 'svm__gamma': 8e-05, 'svm__kernel': 'rbf'}",0.8913265306122449,0.9084362139917695,0.9013649240278135,0.8966053381705105,0.8982565703877179,0.8991979154380111,0.005649926952866353,21,0.9796361315033514,0.9784328739152629,0.9783065080475857,0.9780332056194125,0.977840219681972,0.978449787753517,0.0006282261620036984
3687.9748188495637,226.27635714477285,1235.9286698818207,96.43253951529712,7,0.000244140625,rbf,"{'svm__C': 7, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8953309389430477,0.9064039408866995,0.9020725388601036,0.894723160904601,0.8967320261437909,0.8990525211476486,0.004497601815320273,22,0.9943827396910507,0.9938118022328548,0.9941289087428207,0.9939339761190218,0.9937412185464299,0.9939997290664356,0.0002323153723652084
3475.950292634964,250.68307679317394,1108.8572846889497,173.86681006560963,8,0.000244140625,rbf,"{'svm__C': 8, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8942505133470225,0.9061203319502075,0.9024896265560166,0.8958712022851207,0.8959205020920502,0.8989304352460834,0.004575647064168655,23,0.9947650663942799,0.9941289087428207,0.9943185445260134,0.9943805874840358,0.994187160651549,0.9943560535597398,0.00022334722839619742
4691.708357810974,715.7987333362521,856.614222574234,169.3726110893645,9,0.0003,rbf,"{'svm__C': 9, 'svm__gamma': 0.0003, 'svm__kernel': 'rbf'}",0.8923156001028013,0.9057291666666667,0.9040312093628089,0.8972691807542262,0.8950682056663168,0.8988826725105641,0.005170335906126429,24,0.9954670241971525,0.9946380697050938,0.9952116452786822,0.9949543335249409,0.9947630604164006,0.9950068266244539,0.00030062988617190713
3234.2388828754424,1459.5692858954721,924.4178524971009,524.3748061613071,10,0.000244140625,rbf,"{'svm__C': 10, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8940754039497307,0.9058854031630801,0.9022556390977443,0.8963367108339828,0.895631702851164,0.8988369719791404,0.004484757587607084,25,0.9951474907419231,0.9944476354585488,0.9947664028593312,0.9948272558911808,0.99457111834962,0.9947519806601207,0.00023984985392036138
3163.7975586414336,1065.2188857087929,820.229798078537,589.5385645528229,10,0.0003,rbf,"{'svm__C': 10, 'svm__gamma': 0.0003, 'svm__kernel': 'rbf'}",0.8925449871465295,0.9052083333333333,0.9034608378870674,0.8975559022360895,0.8953579858379229,0.8988256092881886,0.004802234756394428,26,0.9954676029364826,0.9948285769009768,0.9954664453100057,0.995017884517118,0.9948265951331673,0.99512142095955,0.00029061433860554813
3278.97476811409,170.465214883947,1106.7234085559844,51.65991928293311,8,0.0002,rbf,"{'svm__C': 8, 'svm__gamma': 0.0002, 'svm__kernel': 'rbf'}",0.8955453149001537,0.9073498964803313,0.9007514900233221,0.8937938197870683,0.8958660387231816,0.8986613119828114,0.00492089047840589,27,0.9939378469784953,0.9933014354066986,0.9935574408368948,0.9934239928493903,0.9932311621966794,0.9934903756536316,0.00024980994185843
3841.6124336719513,530.3719300313371,717.6561746120453,451.3315178843931,9,0.000244140625,rbf,"{'svm__C': 9, 'svm__gamma': 0.000244140625, 'svm__kernel': 'rbf'}",0.8935078265332307,0.9063553826199741,0.9022049286640726,0.8952974798648999,0.8953427524856097,0.8985416740335573,0.004910620976859314,28,0.994956266360212,0.9942572741194488,0.99457527602272,0.9946988567413936,0.9943791517629024,0.9945733650013354,0.0002449479565334618
3785.3250827789307,157.957270731452,1000.0055857658386,413.19894731125567,9,0.0002,rbf,"{'svm__C': 9, 'svm__gamma': 0.0002, 'svm__kernel': 'rbf'}",0.8950332821300563,0.90641158221303,0.9001297016861219,0.8930980799169694,0.8964976476738108,0.8982340587239976,0.004692626665297457,29,0.9943827396910507,0.9936834045811268,0.9940020418580908,0.9938705146213765,0.9937412185464299,0.993935983859615,0.0002489861331004921
1625.4079483509063,415.4024195664831,842.2647498607636,533.5642210948797,10,0.0002,rbf,"{'svm__C': 10, 'svm__gamma': 0.0002, 'svm__kernel': 'rbf'}",0.894413121476166,0.9059431524547804,0.9004665629860031,0.893959035519834,0.8962633916906193,0.8982090528254805,0.004498221612137673,30,0.994573890839451,0.9939386205576469,0.9941916129444054,0.9941893876508524,0.9939951450108598,0.9941777314006431,0.0002226129606986996
4175.494106626511,240.66795475761913,1410.2359414100647,135.26701239594428,6,0.0004,rbf,"{'svm__C': 6, 'svm__gamma': 0.0004, 'svm__kernel': 'rbf'}",0.8919896640826873,0.9042386185243328,0.9044752682543836,0.89527291721076,0.8942917547568711,0.8980536445658069,0.005256372751673358,31,0.9953393347379174,0.9947650663942799,0.9951481103166496,0.9948907906501469,0.9947630604164006,0.9949812725030789,0.00022741679006577428
4326.916931676865,602.9984138032598,1227.9458584785461,197.47256597026075,5,0.0004,rbf,"{'svm__C': 5, 'svm__gamma': 0.0004, 'svm__kernel': 'rbf'}",0.8917592353397055,0.9055715406748627,0.9037153322867608,0.89527291721076,0.8935270805812418,0.8979692212186661,0.005592488656824322,32,0.9950839558194471,0.9943199948943774,0.9948940515700792,0.9947637292464878,0.9946346448645886,0.9947392752789961,0.0002569443271676517
4182.006548070907,118.61878371306456,1325.9245444774629,128.4962088893567,7,0.0004,rbf,"{'svm__C': 7, 'svm__gamma': 0.0004, 'svm__kernel': 'rbf'}",0.8915289256198347,0.9040020925974366,0.9027522935779817,0.8955067920585162,0.8942917547568711,0.8976163717221279,0.004893169295618382,33,0.9955305835780871,0.9948285769009768,0.9955300127713921,0.9951456310679612,0.9948907906501469,0.9951851189937129,0.0003012003422766297
4359.134440898895,349.4379644221477,1247.1464223861694,48.51511167190543,8,0.0004,rbf,"{'svm__C': 8, 'svm__gamma': 0.0004, 'svm__kernel': 'rbf'}",0.8907259106174116,0.9040020925974366,0.9022280471821756,0.8952181865691142,0.8940554821664465,0.8972459438265169,0.005045167553380767,34,0.9956582811901418,0.9948285769009768,0.9955300127713921,0.9952733776188043,0.9952091983391887,0.9952998893641007,0.000287215826893503
3258.862170982361,983.5605167556715,587.3299376487732,331.17491359144447,10,0.0004,rbf,"{'svm__C': 10, 'svm__gamma': 0.0004, 'svm__kernel': 'rbf'}",0.8907259106174116,0.9040020925974366,0.9021767637031209,0.894929430214323,0.8934707903780069,0.8970609975020599,0.0051364648753549295,35,0.9958500925748579,0.9950833280122597,0.9957849022863712,0.9954647077610987,0.995399948888321,0.9955165959045816,0.0002782906092371906
4844.620493507386,626.0919379166494,911.782252407074,323.6546317691595,9,0.0004,rbf,"{'svm__C': 9, 'svm__gamma': 0.0004, 'svm__kernel': 'rbf'}",0.8909560723514212,0.9037656903765691,0.9018887722980063,0.8946955840083617,0.8934707903780069,0.896955381882473,0.0049791449468967654,36,0.9957224031156228,0.9950197931298684,0.9956577266922094,0.9952733776188043,0.9952091983391887,0.9953764997791387,0.00027003750858763207
4893.450245141983,756.5570053534453,1404.103228712082,151.04221061712673,5,0.0005,rbf,"{'svm__C': 5, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8894668400520156,0.9034120734908136,0.9006831318970048,0.893740136770121,0.8925531914893617,0.8959710747398633,0.005225629609478866,37,0.9954676029364826,0.9948285769009768,0.9954664453100057,0.9950820719167146,0.9949543335249409,0.9951598061178242,0.0002633404117486148
4990.464423799514,170.82569174595943,1480.2943919181823,186.67598755544077,6,0.0005,rbf,"{'svm__C': 6, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8890044190278139,0.9034120734908136,0.9009198423127464,0.8936842105263157,0.8920212765957447,0.8958083643906868,0.005460610492136122,38,0.9957224031156228,0.994956266360212,0.9955300127713921,0.9953369530501437,0.9952091983391887,0.9953509667273119,0.00026309903658882153
5007.180242681503,231.11542984640803,1395.9259331703186,204.12036576004334,8,0.0005,rbf,"{'svm__C': 8, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8889466840052016,0.9031241795746915,0.9004739336492891,0.8944459068175836,0.8914893617021277,0.8956960131497788,0.005344498234287286,39,0.995977782034093,0.9950833280122597,0.9958485022673564,0.9955277280858676,0.9954641282821185,0.995580293736339,0.0003140589926224988
4690.542550802231,156.2481940371033,1510.894291639328,161.45358965306164,7,0.0005,rbf,"{'svm__C': 7, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8890044190278139,0.9031241795746915,0.900815574848724,0.8936842105263157,0.8917265230114392,0.8956709813977968,0.005403074479030133,40,0.9957229492499202,0.9950833280122597,0.9956577266922094,0.9954647077610987,0.9953363572478119,0.99545301379266,0.00023029789270078538
3863.1712747097017,1272.979608552642,731.4842251300812,347.05442064609605,9,0.0005,rbf,"{'svm__C': 9, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8889466840052016,0.9028871391076115,0.9,0.8944459068175836,0.8917265230114392,0.895601250588367,0.005158774221364523,41,0.995977782034093,0.9950833280122597,0.9959126325201175,0.9955271565495207,0.9954641282821185,0.9955930054796219,0.0003258622335451535
2941.8006259441377,963.9592947713663,792.06405377388,413.3915357813076,10,0.0005,rbf,"{'svm__C': 10, 'svm__gamma': 0.0005, 'svm__kernel': 'rbf'}",0.8886576482830385,0.9028871391076115,0.8997104501184522,0.8933929981574098,0.8917265230114392,0.8952749517355901,0.005244823031429965,42,0.995977782034093,0.9951462511176395,0.9960398569238631,0.9955907725733274,0.9954641282821185,0.9956437581862083,0.0003319800601900671
5125.378367471695,385.2361899541604,1681.7096691131592,317.80513982062877,5,0.0006,rbf,"{'svm__C': 5, 'svm__gamma': 0.0006, 'svm__kernel': 'rbf'}",0.8873165618448637,0.899736147757256,0.8978835978835978,0.8901273885350318,0.8895442359249329,0.8929215863891364,0.0049333587475242895,43,0.9957224031156228,0.9950197931298684,0.9955300127713921,0.9953357612932081,0.9952721696907744,0.9953760280001731,0.00023790681760390187
5443.34237332344,248.4197468645117,1615.5747085094451,101.54156492988628,6,0.0006,rbf,"{'svm__C': 6, 'svm__gamma': 0.0006, 'svm__kernel': 'rbf'}",0.8867330886208705,0.9000263782643102,0.8973001588141875,0.8903054448871182,0.8892464467685707,0.8927223034710113,0.005061615312935921,44,0.995977782034093,0.9950833280122597,0.9957849022863712,0.9955265848670757,0.9953993610223643,0.9955543916444327,0.0003095491025041155
4886.9807908058165,1271.2093796793506,935.9216484069824,248.52890337673645,9,0.0006,rbf,"{'svm__C': 9, 'svm__gamma': 0.0006, 'svm__kernel': 'rbf'}",0.8862683438155137,0.9000263782643102,0.8967161016949152,0.8906001062134891,0.8897827835880934,0.8926787427152643,0.00498173742220226,45,0.9960408684546616,0.9951462511176395,0.9960398569238631,0.995590208985748,0.995463548655038,0.99565614682739,0.000345450331712534
5317.718171501159,69.30253901488373,1633.7499773979187,123.257398427158,7,0.0006,rbf,"{'svm__C': 7, 'svm__gamma': 0.0006, 'svm__kernel': 'rbf'}",0.8865006553079947,0.9000263782643102,0.8964786867884564,0.8906001062134891,0.8895442359249329,0.8926300124998366,0.004913859511392688,46,0.9960413740263057,0.99514687100894,0.9959126325201175,0.9955265848670757,0.995463548655038,0.9956182022154954,0.00032256104982186627
5630.75842347145,173.39260110695608,1854.6558534622193,343.7979980330878,8,0.0006,rbf,"{'svm__C': 8, 'svm__gamma': 0.0006, 'svm__kernel': 'rbf'}",0.8862683438155137,0.9000263782643102,0.8964786867884564,0.8903054448871182,0.8895442359249329,0.8925246179360663,0.004997188016010959,47,0.9959772683736671,0.9950827000447027,0.9960398569238631,0.995590208985748,0.995463548655038,0.9956307165966038,0.00035139780020874577
3351.3005242347717,605.7478038761916,610.0130533695221,370.216124260849,10,0.0006,rbf,"{'svm__C': 10, 'svm__gamma': 0.0006, 'svm__kernel': 'rbf'}",0.8862683438155137,0.9000263782643102,0.8964238410596026,0.8897156524049961,0.8897827835880934,0.892443399826503,0.005019885972826287,48,0.9960408684546616,0.9951462511176395,0.9960398569238631,0.9957180290151467,0.995463548655038,0.9956817108332698,0.00034435263106865976
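
Across the three grids the best mean test F1 creeps upward: 0.8974 (exp019-2), 0.9007 (exp019-3), 0.9015 (exp019-4, C=5, gamma=0.0002). All three CSVs share one schema — one row per hyper-parameter candidate, ranked by mean test score — so they can be compared directly. A quick sketch with pandas, run from the repo root:

```python
import pandas as pd

for exp in ["exp019-2", "exp019-3", "exp019-4"]:
    res = pd.read_csv(f"subtask_1/grid_cv_results.{exp}.csv")
    best = res.sort_values("rank_test_score").iloc[0]
    print(f"{exp}: best F1 = {best['mean_test_score']:.4f} with {best['params']}")
```
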
subtask_1/submission_subtask1-2.ipynb
ADDED
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "d10bfa50537af75f",
+   "metadata": {},
+   "source": [
+    "## Experiment exp027-2\n",
+    "xlm-roberta-large, Batch Size: 32, Learning Rate: 2e-5, Warmup Steps: 500"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "id": "9748a35a024779ae",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-27T22:06:52.194727Z",
+     "start_time": "2025-06-27T22:06:52.191088Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "from transformers import (\n",
+    "    AutoTokenizer,\n",
+    "    BertForTokenClassification,\n",
+    "    AutoModelForTokenClassification\n",
+    ")\n",
+    "import torch\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 56,
+   "id": "4ae3d9e4c556a288",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-27T22:07:26.334867Z",
+     "start_time": "2025-06-27T22:07:26.325629Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "test_comments_spans = pd.read_csv(\"./submissions/task2-predicted.csv\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "id": "156c9b1c48a954b4",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-27T22:07:30.302897Z",
+     "start_time": "2025-06-27T22:07:30.290021Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>document</th>\n",
+       "      <th>comment_id</th>\n",
+       "      <th>type</th>\n",
+       "      <th>start</th>\n",
+       "      <th>end</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>2</td>\n",
+       "      <td>compliment</td>\n",
+       "      <td>0</td>\n",
+       "      <td>21</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>4</td>\n",
+       "      <td>affection declaration</td>\n",
+       "      <td>0</td>\n",
+       "      <td>19</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>5</td>\n",
+       "      <td>affection declaration</td>\n",
+       "      <td>0</td>\n",
+       "      <td>25</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>5</td>\n",
+       "      <td>affection declaration</td>\n",
+       "      <td>26</td>\n",
+       "      <td>56</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>5</td>\n",
+       "      <td>positive feedback</td>\n",
+       "      <td>57</td>\n",
+       "      <td>71</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5498</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>526</td>\n",
+       "      <td>affection declaration</td>\n",
+       "      <td>0</td>\n",
+       "      <td>17</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5499</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>526</td>\n",
+       "      <td>positive feedback</td>\n",
+       "      <td>30</td>\n",
+       "      <td>59</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5500</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>526</td>\n",
+       "      <td>positive feedback</td>\n",
+       "      <td>64</td>\n",
+       "      <td>104</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5501</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>526</td>\n",
+       "      <td>affection declaration</td>\n",
+       "      <td>105</td>\n",
+       "      <td>106</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5502</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>526</td>\n",
+       "      <td>affection declaration</td>\n",
+       "      <td>105</td>\n",
+       "      <td>114</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>5503 rows × 5 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "     document  comment_id                   type  start  end\n",
+       "0     NDY-004           2             compliment      0   21\n",
+       "1     NDY-004           4  affection declaration      0   19\n",
+       "2     NDY-004           5  affection declaration      0   25\n",
+       "3     NDY-004           5  affection declaration     26   56\n",
+       "4     NDY-004           5      positive feedback     57   71\n",
+       "...       ...         ...                    ...    ...  ...\n",
+       "5498  NDY-203         526  affection declaration      0   17\n",
+       "5499  NDY-203         526      positive feedback     30   59\n",
+       "5500  NDY-203         526      positive feedback     64  104\n",
+       "5501  NDY-203         526  affection declaration    105  106\n",
+       "5502  NDY-203         526  affection declaration    105  114\n",
+       "\n",
+       "[5503 rows x 5 columns]"
+      ]
+     },
+     "execution_count": 57,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_comments_spans"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 58,
+   "id": "2b63b3b12b9648f6",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-27T22:07:50.819958Z",
+     "start_time": "2025-06-27T22:07:50.699928Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>document</th>\n",
+       "      <th>comment_id</th>\n",
+       "      <th>comment</th>\n",
+       "      <th>predicted_labels</th>\n",
+       "      <th>predicted_probs</th>\n",
+       "      <th>offset_mapping</th>\n",
+       "      <th>text_tokens</th>\n",
+       "      <th>predicted_spans</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>1</td>\n",
+       "      <td>Lol i love lochis</td>\n",
+       "      <td>[0, 0, 0, 0, 0, 0, 0, 0]</td>\n",
+       "      <td>[[0.99999654, 1.7456429e-07, 1.6115715e-07, 1....</td>\n",
+       "      <td>[[0, 0], [0, 1], [1, 3], [4, 5], [6, 10], [11,...</td>\n",
+       "      <td>[▁L, ol, ▁i, ▁love, ▁loc, his]</td>\n",
+       "      <td>[]</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>2</td>\n",
+       "      <td>ihr singt voll gut :)</td>\n",
+       "      <td>[0, 2, 12, 12, 12, 12, 12, 0]</td>\n",
+       "      <td>[[0.9999976, 1.1218729e-07, 1.239344e-07, 1.50...</td>\n",
+       "      <td>[[0, 0], [0, 3], [4, 8], [8, 9], [10, 14], [15...</td>\n",
+       "      <td>[▁ihr, ▁sing, t, ▁voll, ▁gut, ▁:)]</td>\n",
+       "      <td>[{'type': 'compliment', 'start': 0, 'end': 21,...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>3</td>\n",
+       "      <td>Junge fick dich</td>\n",
+       "      <td>[0, 0, 0, 0, 0, 0]</td>\n",
+       "      <td>[[0.9999981, 5.8623616e-08, 1.05891374e-07, 1....</td>\n",
+       "      <td>[[0, 0], [0, 4], [4, 5], [6, 10], [11, 15], [0...</td>\n",
+       "      <td>[▁Jung, e, ▁fick, ▁dich]</td>\n",
+       "      <td>[]</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>4</td>\n",
+       "      <td>Ihr seit die besten</td>\n",
+       "      <td>[0, 3, 13, 13, 13, 0]</td>\n",
+       "      <td>[[0.99999774, 1.6417343e-07, 1.384722e-07, 1.1...</td>\n",
+       "      <td>[[0, 0], [0, 3], [4, 8], [9, 12], [13, 19], [0...</td>\n",
+       "      <td>[▁Ihr, ▁seit, ▁die, ▁besten]</td>\n",
+       "      <td>[{'type': 'affection declaration', 'start': 0,...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>5</td>\n",
+       "      <td>ihr seit die ALLER besten ich finde euch soooo...</td>\n",
+       "      <td>[0, 3, 13, 13, 13, 13, 13, 3, 13, 13, 13, 13, ...</td>\n",
+       "      <td>[[0.99999785, 1.2960982e-07, 1.4320104e-07, 1....</td>\n",
+       "      <td>[[0, 0], [0, 3], [4, 8], [9, 12], [13, 17], [1...</td>\n",
+       "      <td>[▁ihr, ▁seit, ▁die, ▁ALLE, R, ▁besten, ▁ich, ▁...</td>\n",
+       "      <td>[{'type': 'affection declaration', 'start': 0,...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9224</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>522</td>\n",
+       "      <td>hihi kannst du mich grüßen 💕 👋 😍 Achso wusstes...</td>\n",
+       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 11, 0, 11, 11, ...</td>\n",
+       "      <td>[[0.99999774, 1.8107521e-07, 1.0220851e-07, 9....</td>\n",
+       "      <td>[[0, 0], [0, 4], [5, 11], [12, 14], [15, 19], ...</td>\n",
+       "      <td>[▁hihi, ▁kannst, ▁du, ▁mich, ▁gr, üß, en, ▁, 💕...</td>\n",
+       "      <td>[{'type': 'positive feedback', 'start': 27, 'e...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9225</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>523</td>\n",
+       "      <td>#Glocke aktiviert 👑 Ich liebe deine Videos 💍 💎...</td>\n",
+       "      <td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 11, 11, 11, 11,...</td>\n",
+       "      <td>[[0.9999976, 1.1908668e-07, 8.492378e-08, 6.60...</td>\n",
+       "      <td>[[0, 0], [0, 1], [1, 2], [2, 6], [6, 7], [8, 1...</td>\n",
+       "      <td>[▁#, G, lock, e, ▁aktiv, iert, ▁, 👑, ▁Ich, ▁li...</td>\n",
+       "      <td>[{'type': 'positive feedback', 'start': 20, 'e...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9226</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>524</td>\n",
+       "      <td>Bist die beste ❤ Bitte Grüße mich 💕 ❤ 😘 😍</td>\n",
+       "      <td>[0, 3, 13, 13, 13, 13, 0, 0, 0, 1, 1, 11, 11, ...</td>\n",
+       "      <td>[[0.9999974, 2.1362885e-07, 1.2580301e-07, 9.5...</td>\n",
+       "      <td>[[0, 0], [0, 3], [3, 4], [5, 8], [9, 14], [15,...</td>\n",
+       "      <td>[▁Bis, t, ▁die, ▁beste, ▁❤, ▁Bitte, ▁Grüße, ▁m...</td>\n",
+       "      <td>[{'type': 'affection declaration', 'start': 0,...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9227</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>525</td>\n",
+       "      <td>Hi Bonny ❤️ War letztens auf'm Flughafen , und...</td>\n",
+       "      <td>[0, 0, 0, 0, 1, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td>\n",
+       "      <td>[[0.99999523, 6.63842e-07, 2.0147786e-07, 1.16...</td>\n",
+       "      <td>[[0, 0], [0, 2], [3, 6], [6, 8], [9, 10], [10,...</td>\n",
+       "      <td>[▁Hi, ▁Bon, ny, ▁❤, ️, ▁War, ▁letzten, s, ▁auf...</td>\n",
+       "      <td>[{'type': 'positive feedback', 'start': 9, 'en...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9228</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>526</td>\n",
+       "      <td>du bist die beste ich bin neu ich hab dich sof...</td>\n",
+       "      <td>[0, 3, 13, 13, 13, 0, 0, 0, 1, 11, 11, 11, 11,...</td>\n",
+       "      <td>[[0.999997, 3.4811254e-07, 7.750037e-08, 7.272...</td>\n",
+       "      <td>[[0, 0], [0, 2], [3, 7], [8, 11], [12, 17], [1...</td>\n",
+       "      <td>[▁du, ▁bist, ▁die, ▁beste, ▁ich, ▁bin, ▁neu, ▁...</td>\n",
+       "      <td>[{'type': 'affection declaration', 'start': 0,...</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>9229 rows × 8 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "     document  comment_id                                            comment  \\\n",
+       "0     NDY-004           1                                  Lol i love lochis   \n",
+       "1     NDY-004           2                              ihr singt voll gut :)   \n",
+       "2     NDY-004           3                                    Junge fick dich   \n",
+       "3     NDY-004           4                                Ihr seit die besten   \n",
+       "4     NDY-004           5  ihr seit die ALLER besten ich finde euch soooo...   \n",
+       "...       ...         ...                                                ...   \n",
+       "9224  NDY-203         522  hihi kannst du mich grüßen 💕 👋 😍 Achso wusstes...   \n",
+       "9225  NDY-203         523  #Glocke aktiviert 👑 Ich liebe deine Videos 💍 💎...   \n",
+       "9226  NDY-203         524          Bist die beste ❤ Bitte Grüße mich 💕 ❤ 😘 😍   \n",
+       "9227  NDY-203         525  Hi Bonny ❤️ War letztens auf'm Flughafen , und...   \n",
+       "9228  NDY-203         526  du bist die beste ich bin neu ich hab dich sof...   \n",
+       "\n",
+       "                                       predicted_labels  \\\n",
+       "0                              [0, 0, 0, 0, 0, 0, 0, 0]   \n",
+       "1                         [0, 2, 12, 12, 12, 12, 12, 0]   \n",
+       "2                                    [0, 0, 0, 0, 0, 0]   \n",
+       "3                                 [0, 3, 13, 13, 13, 0]   \n",
+       "4     [0, 3, 13, 13, 13, 13, 13, 3, 13, 13, 13, 13, ...   \n",
+       "...                                                 ...   \n",
+       "9224  [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 11, 0, 11, 11, ...   \n",
+       "9225  [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 11, 11, 11, 11,...   \n",
+       "9226  [0, 3, 13, 13, 13, 13, 0, 0, 0, 1, 1, 11, 11, ...   \n",
+       "9227  [0, 0, 0, 0, 1, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   \n",
+       "9228  [0, 3, 13, 13, 13, 0, 0, 0, 1, 11, 11, 11, 11,...   \n",
+       "\n",
+       "                                        predicted_probs  \\\n",
+       "0     [[0.99999654, 1.7456429e-07, 1.6115715e-07, 1....   \n",
+       "1     [[0.9999976, 1.1218729e-07, 1.239344e-07, 1.50...   \n",
+       "2     [[0.9999981, 5.8623616e-08, 1.05891374e-07, 1....   \n",
+       "3     [[0.99999774, 1.6417343e-07, 1.384722e-07, 1.1...   \n",
+       "4     [[0.99999785, 1.2960982e-07, 1.4320104e-07, 1....   \n",
+       "...                                                 ...   \n",
+       "9224  [[0.99999774, 1.8107521e-07, 1.0220851e-07, 9....   \n",
+       "9225  [[0.9999976, 1.1908668e-07, 8.492378e-08, 6.60...   \n",
+       "9226  [[0.9999974, 2.1362885e-07, 1.2580301e-07, 9.5...   \n",
+       "9227  [[0.99999523, 6.63842e-07, 2.0147786e-07, 1.16...   \n",
+       "9228  [[0.999997, 3.4811254e-07, 7.750037e-08, 7.272...   \n",
+       "\n",
+       "                                         offset_mapping  \\\n",
+       "0     [[0, 0], [0, 1], [1, 3], [4, 5], [6, 10], [11,...   \n",
+       "1     [[0, 0], [0, 3], [4, 8], [8, 9], [10, 14], [15...   \n",
+       "2     [[0, 0], [0, 4], [4, 5], [6, 10], [11, 15], [0...   \n",
+       "3     [[0, 0], [0, 3], [4, 8], [9, 12], [13, 19], [0...   \n",
+       "4     [[0, 0], [0, 3], [4, 8], [9, 12], [13, 17], [1...   \n",
+       "...                                                 ...   \n",
+       "9224  [[0, 0], [0, 4], [5, 11], [12, 14], [15, 19], ...   \n",
+       "9225  [[0, 0], [0, 1], [1, 2], [2, 6], [6, 7], [8, 1...   \n",
+       "9226  [[0, 0], [0, 3], [3, 4], [5, 8], [9, 14], [15,...   \n",
+       "9227  [[0, 0], [0, 2], [3, 6], [6, 8], [9, 10], [10,...   \n",
+       "9228  [[0, 0], [0, 2], [3, 7], [8, 11], [12, 17], [1...   \n",
+       "\n",
+       "                                            text_tokens  \\\n",
+       "0                        [▁L, ol, ▁i, ▁love, ▁loc, his]   \n",
+       "1                    [▁ihr, ▁sing, t, ▁voll, ▁gut, ▁:)]   \n",
+       "2                              [▁Jung, e, ▁fick, ▁dich]   \n",
+       "3                          [▁Ihr, ▁seit, ▁die, ▁besten]   \n",
+       "4     [▁ihr, ▁seit, ▁die, ▁ALLE, R, ▁besten, ▁ich, ▁...   \n",
+       "...                                                 ...   \n",
+       "9224  [▁hihi, ▁kannst, ▁du, ▁mich, ▁gr, üß, en, ▁, 💕...   \n",
+       "9225  [▁#, G, lock, e, ▁aktiv, iert, ▁, 👑, ▁Ich, ▁li...   \n",
+       "9226  [▁Bis, t, ▁die, ▁beste, ▁❤, ▁Bitte, ▁Grüße, ▁m...   \n",
+       "9227  [▁Hi, ▁Bon, ny, ▁❤, ️, ▁War, ▁letzten, s, ▁auf...   \n",
+       "9228  [▁du, ▁bist, ▁die, ▁beste, ▁ich, ▁bin, ▁neu, ▁...   \n",
+       "\n",
+       "                                        predicted_spans  \n",
+       "0                                                    []  \n",
+       "1     [{'type': 'compliment', 'start': 0, 'end': 21,...  \n",
+       "2                                                    []  \n",
+       "3     [{'type': 'affection declaration', 'start': 0,...  \n",
+       "4     [{'type': 'affection declaration', 'start': 0,...  \n",
+       "...                                                 ...  \n",
+       "9224  [{'type': 'positive feedback', 'start': 27, 'e...  \n",
+       "9225  [{'type': 'positive feedback', 'start': 20, 'e...  \n",
+       "9226  [{'type': 'affection declaration', 'start': 0,...  \n",
+       "9227  [{'type': 'positive feedback', 'start': 9, 'en...  \n",
+       "9228  [{'type': 'affection declaration', 'start': 0,...  \n",
+       "\n",
+       "[9229 rows x 8 columns]"
+      ]
+     },
+     "execution_count": 58,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_comments"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "id": "263a51fec4f4672",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-27T22:09:58.052637Z",
+     "start_time": "2025-06-27T22:09:57.997729Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "test_comments['has_spans'] = test_comments.apply(lambda x: len(x['predicted_spans']) > 0, axis=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "id": "5fa67bbeb303ca3a",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-27T22:10:35.264094Z",
+     "start_time": "2025-06-27T22:10:35.260301Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "test_comments['flausch'] = test_comments['has_spans'].map({True: 'yes', False: 'no'})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "id": "fd7679e665286b70",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-27T22:11:57.164479Z",
+     "start_time": "2025-06-27T22:11:57.150708Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "test_comments[[\"document\",\"comment_id\",\"flausch\"]].to_csv(f'./submissions/task1-predicted.csv', index=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "id": "bd9d8b153b8d27ed",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-27T22:12:25.303426Z",
+     "start_time": "2025-06-27T22:12:24.850361Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+     ]
+    }
+   ],
+   "source": [
+    "!cp './submissions/task1-predicted.csv' './submissions/subtask1_submission2.csv'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "id": "5a2738b19dcd4292",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-27T22:12:43.388207Z",
+     "start_time": "2025-06-27T22:12:42.945847Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "document,comment_id,flausch\r\n",
+      "NDY-004,1,no\r\n",
+      "NDY-004,2,yes\r\n",
+      "NDY-004,3,no\r\n",
+      "NDY-004,4,yes\r\n",
+      "NDY-004,5,yes\r\n",
+      "NDY-004,6,yes\r\n",
+      "NDY-004,7,no\r\n",
+      "NDY-004,8,yes\r\n",
+      "NDY-004,9,no\r\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+     ]
+    }
+   ],
+   "source": [
+    "!head -n 10 './submissions/task1-predicted.csv'"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
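The notebook above derives the subtask 1 submission directly from the subtask 2 span predictions: a comment is labelled flausch exactly when at least one span was predicted for it. A minimal, self-contained sketch of that derivation (toy dataframe mirroring the column names in the cells above, not the actual pipeline run):

```python
import pandas as pd

# A comment counts as flausch iff at least one span was predicted for it.
comments = pd.DataFrame({
    "document": ["NDY-004", "NDY-004"],
    "comment_id": [1, 2],
    "predicted_spans": [[], [{"type": "compliment", "start": 0, "end": 21}]],
})
comments["flausch"] = comments["predicted_spans"].map(
    lambda spans: "yes" if len(spans) > 0 else "no"
)
print(comments[["document", "comment_id", "flausch"]])
#   document  comment_id flausch
# 0  NDY-004           1      no
# 1  NDY-004           2     yes
```

Reusing the span model for the binary task keeps the two submissions consistent by construction.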
subtask_1/submission_subtask1.ipynb
ADDED
@@ -0,0 +1,719 @@
+{
+ "cells": [
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": [
+    "## Experiment 019-4\n",
+    "\n",
+    "SVM with RBF kernel, C=5 and gamma=0.0002"
+   ],
+   "id": "8d9679176b5367c7"
+  },
+  {
+   "cell_type": "code",
+   "id": "initial_id",
+   "metadata": {
+    "collapsed": true,
+    "ExecuteTime": {
+     "end_time": "2025-06-23T18:30:56.081332Z",
+     "start_time": "2025-06-23T18:30:55.935044Z"
+    }
+   },
+   "source": [
+    "import os\n",
+    "from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, make_scorer, classification_report\n",
+    "from sklearn.model_selection import StratifiedKFold, train_test_split, GridSearchCV\n",
+    "from sklearn.pipeline import Pipeline\n",
+    "from sklearn.preprocessing import StandardScaler\n",
+    "from sklearn.svm import SVC\n",
+    "import time\n",
+    "import pickle\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import torch\n",
+    "from torch import Tensor\n",
+    "from transformers import AutoModel, AutoTokenizer\n",
+    "from transformers.utils import is_flash_attn_2_available\n",
+    "import wandb\n",
+    "from wandb import AlertLevel\n",
+    "\n",
+    "\n",
+    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
+    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n",
+    "os.environ[\"WANDB_PROJECT\"] = \"GermEval2025-Substask1\"\n",
+    "os.environ[\"WANDB_LOG_MODEL\"] = \"false\"\n",
+    "\n",
+    "if torch.cuda.is_available():\n",
+    "    device = torch.device('cuda')\n",
+    "else:\n",
+    "    device = torch.device('cpu')\n",
+    "    print(\"CUDA not available, using CPU\")\n",
+    "\n",
+    "experiment_name = \"exp019-4\"\n",
+    "\n",
+    "testing_mode = False\n",
+    "\n",
+    "# Load data\n",
+    "comments = pd.read_csv(\"./share-GermEval2025-data/Data/training data/comments.csv\")\n",
+    "task1 = pd.read_csv(\"./share-GermEval2025-data/Data/training data/task1.csv\")\n",
+    "comments = comments.merge(task1, on=[\"document\", \"comment_id\"])\n",
+    "\n",
+    "# Remove duplicates\n",
+    "df = comments.drop_duplicates(subset=['comment', 'flausch'])\n",
+    "df.reset_index(drop=True, inplace=True)"
+   ],
+   "outputs": [],
+   "execution_count": 2
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": [
+    "def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:\n",
+    "    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])\n",
+    "    if left_padding:\n",
+    "        return last_hidden_states[:, -1]\n",
+    "    else:\n",
+    "        sequence_lengths = attention_mask.sum(dim=1) - 1\n",
+    "        batch_size = last_hidden_states.shape[0]\n",
+    "        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]\n",
+    "\n",
+    "class Qwen3Embedder:\n",
+    "    def __init__(self, model_name='Qwen/Qwen3-Embedding-8B', instruction=None, max_length=1024):\n",
+    "        if instruction is None:\n",
+    "            instruction = 'Classify a given comment as either flausch (a positive, supportive expression) or non-flausch.'\n",
+    "        self.instruction = instruction\n",
+    "\n",
+    "        if is_flash_attn_2_available():\n",
+    "            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)\n",
+    "        else:\n",
+    "            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16)\n",
+    "\n",
+    "        self.model = self.model.cuda()\n",
+    "        self.model.eval()\n",
+    "\n",
+    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left')\n",
+    "        self.max_length = max_length\n",
+    "\n",
+    "    def get_detailed_instruct(self, query: str) -> str:\n",
+    "        return f'Instruct: {self.instruction}\\nQuery:{query}'\n",
+    "\n",
+    "    def encode_batch(self, texts, batch_size=32):\n",
+    "        \"\"\"Encode texts in batches to handle memory efficiently\"\"\"\n",
+    "        all_embeddings = []\n",
+    "\n",
+    "        for i in range(0, len(texts), batch_size):\n",
+    "            batch_texts = [self.get_detailed_instruct(comment) for comment in texts[i:i + batch_size]]\n",
+    "\n",
+    "            # Tokenize batch\n",
+    "            inputs = self.tokenizer(\n",
+    "                batch_texts,\n",
+    "                padding=True,\n",
+    "                truncation=True,\n",
+    "                max_length=self.max_length,\n",
+    "                return_tensors='pt'\n",
+    "            ).to(device)\n",
+    "\n",
+    "            # Get embeddings\n",
+    "            with torch.no_grad():\n",
+    "                outputs = self.model(**inputs)\n",
+    "                # Last-token pooling\n",
+    "                embeddings = last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])\n",
+    "                #embeddings = embeddings.float()\n",
+    "\n",
+    "            all_embeddings.append(embeddings.cpu().numpy())\n",
+    "\n",
+    "        # Normalize embeddings (should I?)\n",
+    "        #import torch.nn.functional as F\n",
+    "        #output = F.normalize(all_embeddings, p=2, dim=1)\n",
+    "        return np.vstack(all_embeddings)\n",
+    "\n",
+    "# Initialize embedder\n",
+    "print(\"Loading Qwen3 Embeddings v3...\")\n",
+    "embedder = Qwen3Embedder(instruction='Classify a given comment as either flausch (a positive, supportive expression) or non-flausch')\n",
+    "\n",
+    "X, y = df[\"comment\"], df[\"flausch\"].map(dict(yes=1, no=0))\n",
+    "\n",
+    "# load embeddings if they exist\n",
+    "embeddings_file = f'Qwen3-Embedding-8B-{experiment_name}.npy'\n",
+    "if os.path.exists(embeddings_file):\n",
+    "    print(f\"Loading existing embeddings from {embeddings_file}\")\n",
+    "    X_embeddings = np.load(embeddings_file)\n",
+    "else:\n",
+    "    print(\"Embeddings not found, generating new embeddings...\")\n",
+    "    # Encode texts in batches to avoid memory issues\n",
+    "    X_embeddings = embedder.encode_batch(X.tolist(), batch_size=64)\n",
+    "    print(f\"Generated embeddings with shape: {X_embeddings.shape}\")\n",
+    "\n",
+    "    # save embeddings to avoid recomputation\n",
+    "    np.save(embeddings_file, X_embeddings)\n",
+    "\n",
+    "pipe = Pipeline([\n",
+    "    (\"scaler\", StandardScaler()),\n",
+    "    (\"svm\", SVC(random_state=42, C=5, gamma=0.0002, cache_size=2000))\n",
+    "])\n",
+    "\n",
+    "f1_pos_scorer = make_scorer(f1_score, pos_label=1, average='binary')\n",
+    "\n",
+    "X_train = X_embeddings\n",
+    "y_train = y\n",
+    "\n",
+    "pipe.fit(X_train, y_train)"
+   ],
+   "id": "59ef5a54cb69530f",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-23T18:30:59.602524Z",
+     "start_time": "2025-06-23T18:30:59.570290Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "test_data: pd.DataFrame = pd.read_csv(\"./share-GermEval2025-data/Data/test data/comments.csv\")\n",
+    "test_data"
+   ],
+   "id": "a842bfa29d59c84b",
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "     document  comment_id                                            comment\n",
+       "0     NDY-004           1                                  Lol i love lochis\n",
+       "1     NDY-004           2                              ihr singt voll gut :)\n",
+       "2     NDY-004           3                                    Junge fick dich\n",
+       "3     NDY-004           4                                Ihr seit die besten\n",
+       "4     NDY-004           5  ihr seit die ALLER besten ich finde euch soooo...\n",
+       "...       ...         ...                                                ...\n",
+       "9224  NDY-203         522  hihi kannst du mich grüßen 💕 👋 😍 Achso wusstes...\n",
+       "9225  NDY-203         523  #Glocke aktiviert 👑 Ich liebe deine Videos 💍 💎...\n",
+       "9226  NDY-203         524          Bist die beste ❤ Bitte Grüße mich 💕 ❤ 😘 😍\n",
+       "9227  NDY-203         525  Hi Bonny ❤️ War letztens auf'm Flughafen , und...\n",
+       "9228  NDY-203         526  du bist die beste ich bin neu ich hab dich sof...\n",
+       "\n",
+       "[9229 rows x 3 columns]"
+      ],
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>document</th>\n",
+       "      <th>comment_id</th>\n",
+       "      <th>comment</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>1</td>\n",
+       "      <td>Lol i love lochis</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>2</td>\n",
+       "      <td>ihr singt voll gut :)</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>3</td>\n",
+       "      <td>Junge fick dich</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>4</td>\n",
+       "      <td>Ihr seit die besten</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>5</td>\n",
+       "      <td>ihr seit die ALLER besten ich finde euch soooo...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9224</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>522</td>\n",
+       "      <td>hihi kannst du mich grüßen 💕 👋 😍 Achso wusstes...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9225</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>523</td>\n",
+       "      <td>#Glocke aktiviert 👑 Ich liebe deine Videos 💍 💎...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9226</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>524</td>\n",
+       "      <td>Bist die beste ❤ Bitte Grüße mich 💕 ❤ 😘 😍</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9227</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>525</td>\n",
+       "      <td>Hi Bonny ❤️ War letztens auf'm Flughafen , und...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9228</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>526</td>\n",
+       "      <td>du bist die beste ich bin neu ich hab dich sof...</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>9229 rows × 3 columns</p>\n",
+       "</div>"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 3
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-23T19:22:07.211246Z",
+     "start_time": "2025-06-23T19:17:34.390901Z"
+    }
+   },
+   "cell_type": "code",
+   "source": "X_test_data = embedder.encode_batch(test_data['comment'].tolist(), batch_size=64)",
+   "id": "b2f18769fe09b609",
+   "outputs": [],
+   "execution_count": 6
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-23T19:25:42.858436Z",
+     "start_time": "2025-06-23T19:22:07.287233Z"
+    }
+   },
+   "cell_type": "code",
+   "source": "y_prediction = pipe.predict(X_test_data)",
+   "id": "3a7abacf1694b415",
+   "outputs": [],
+   "execution_count": 7
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-23T19:31:30.676051Z",
+     "start_time": "2025-06-23T19:31:30.667660Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "test_data['flausch'] = y_prediction\n",
+    "test_data['flausch'] = test_data['flausch'].map({1: 'yes', 0: 'no'})\n",
+    "test_data"
+   ],
+   "id": "d342aed9b9070ad4",
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "     document  comment_id                                            comment  \\\n",
+       "0     NDY-004           1                                  Lol i love lochis   \n",
+       "1     NDY-004           2                              ihr singt voll gut :)   \n",
+       "2     NDY-004           3                                    Junge fick dich   \n",
+       "3     NDY-004           4                                Ihr seit die besten   \n",
+       "4     NDY-004           5  ihr seit die ALLER besten ich finde euch soooo...   \n",
+       "...       ...         ...                                                ...   \n",
+       "9224  NDY-203         522  hihi kannst du mich grüßen 💕 👋 😍 Achso wusstes...   \n",
+       "9225  NDY-203         523  #Glocke aktiviert 👑 Ich liebe deine Videos 💍 💎...   \n",
+       "9226  NDY-203         524          Bist die beste ❤ Bitte Grüße mich 💕 ❤ 😘 😍   \n",
+       "9227  NDY-203         525  Hi Bonny ❤️ War letztens auf'm Flughafen , und...   \n",
+       "9228  NDY-203         526  du bist die beste ich bin neu ich hab dich sof...   \n",
+       "\n",
+       "     flausch  \n",
+       "0         no  \n",
+       "1        yes  \n",
+       "2         no  \n",
+       "3        yes  \n",
+       "4        yes  \n",
+       "...      ...  \n",
+       "9224      no  \n",
+       "9225     yes  \n",
+       "9226     yes  \n",
+       "9227     yes  \n",
+       "9228     yes  \n",
+       "\n",
+       "[9229 rows x 4 columns]"
+      ],
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>document</th>\n",
+       "      <th>comment_id</th>\n",
+       "      <th>comment</th>\n",
+       "      <th>flausch</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>1</td>\n",
+       "      <td>Lol i love lochis</td>\n",
+       "      <td>no</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>2</td>\n",
+       "      <td>ihr singt voll gut :)</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>3</td>\n",
+       "      <td>Junge fick dich</td>\n",
+       "      <td>no</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>4</td>\n",
+       "      <td>Ihr seit die besten</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>5</td>\n",
+       "      <td>ihr seit die ALLER besten ich finde euch soooo...</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9224</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>522</td>\n",
+       "      <td>hihi kannst du mich grüßen 💕 👋 😍 Achso wusstes...</td>\n",
+       "      <td>no</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9225</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>523</td>\n",
+       "      <td>#Glocke aktiviert 👑 Ich liebe deine Videos 💍 💎...</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9226</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>524</td>\n",
+       "      <td>Bist die beste ❤ Bitte Grüße mich 💕 ❤ 😘 😍</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9227</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>525</td>\n",
+       "      <td>Hi Bonny ❤️ War letztens auf'm Flughafen , und...</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9228</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>526</td>\n",
+       "      <td>du bist die beste ich bin neu ich hab dich sof...</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>9229 rows × 4 columns</p>\n",
+       "</div>"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 11
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-23T19:33:51.519362Z",
+     "start_time": "2025-06-23T19:33:51.512704Z"
+    }
+   },
+   "cell_type": "code",
+   "source": "test_data[['document', 'comment_id', 'flausch']]",
+   "id": "ac4077f355d0a379",
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "     document  comment_id flausch\n",
+       "0     NDY-004           1      no\n",
+       "1     NDY-004           2     yes\n",
+       "2     NDY-004           3      no\n",
+       "3     NDY-004           4     yes\n",
+       "4     NDY-004           5     yes\n",
+       "...       ...         ...     ...\n",
+       "9224  NDY-203         522      no\n",
+       "9225  NDY-203         523     yes\n",
+       "9226  NDY-203         524     yes\n",
+       "9227  NDY-203         525     yes\n",
+       "9228  NDY-203         526     yes\n",
+       "\n",
+       "[9229 rows x 3 columns]"
+      ],
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>document</th>\n",
+       "      <th>comment_id</th>\n",
+       "      <th>flausch</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>1</td>\n",
+       "      <td>no</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>2</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>3</td>\n",
+       "      <td>no</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>4</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>NDY-004</td>\n",
+       "      <td>5</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9224</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>522</td>\n",
+       "      <td>no</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9225</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>523</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9226</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>524</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9227</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>525</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9228</th>\n",
+       "      <td>NDY-203</td>\n",
+       "      <td>526</td>\n",
+       "      <td>yes</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>9229 rows × 3 columns</p>\n",
+       "</div>"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 12
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-23T19:34:57.446239Z",
+     "start_time": "2025-06-23T19:34:57.431741Z"
+    }
+   },
+   "cell_type": "code",
+   "source": "test_data[['document', 'comment_id', 'flausch']].to_csv(f'./submissions/subtask1_submission1.csv', index=False)",
+   "id": "ce927f8936231813",
+   "outputs": [],
+   "execution_count": 16
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-06-23T19:37:22.875657Z",
+     "start_time": "2025-06-23T19:37:22.653931Z"
+    }
+   },
+   "cell_type": "code",
+   "source": "!head -n 10 './submissions/subtask1_submission1.csv'",
+   "id": "e358ae2660d91769",
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "document,comment_id,flausch\r\n",
+      "NDY-004,1,no\r\n",
+      "NDY-004,2,yes\r\n",
+      "NDY-004,3,no\r\n",
+      "NDY-004,4,yes\r\n",
+      "NDY-004,5,yes\r\n",
+      "NDY-004,6,yes\r\n",
+      "NDY-004,7,no\r\n",
+      "NDY-004,8,no\r\n",
+      "NDY-004,9,no\r\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+     ]
+    }
+   ],
+   "execution_count": 19
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": "!cp './submissions/subtask1_submission1.csv' './submissions/task1-predicted.csv'",
+   "id": "e820c01a833df1db",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": [
+    " Score for Subtask 1:\n",
+    "\n",
+    " → 0.88"
+   ],
+   "id": "c441568bcdde6462"
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
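The notebook above feeds instruction-prefixed comments through Qwen/Qwen3-Embedding-8B, pools the last token of each sequence, and trains an RBF-kernel SVM on the resulting vectors. The pooling step is the easiest part to get wrong, so here is a small self-contained check of the `last_token_pool` logic with toy tensors (not the notebook's real model outputs): with left padding the final position is always a real token, while with right padding the index of the last non-padding token must be derived from the attention mask.

```python
import torch

# Same logic as last_token_pool in the notebook above, checked on toy data.
def last_token_pool(last_hidden_states, attention_mask):
    # If every sequence has a real token in the last position, the batch is left-padded.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    # Otherwise pick, per sequence, the hidden state of the last non-padding token.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size), sequence_lengths]

hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
right_padded_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
pooled = last_token_pool(hidden, right_padded_mask)
print(pooled.shape)  # torch.Size([2, 3]) – one vector per sequence
```

Note that the tokenizer in the notebook is loaded with `padding_side='left'`, so in practice the cheap left-padding branch is the one taken.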
subtask_2/exp027-1.py
ADDED
@@ -0,0 +1,736 @@
import os
import pickle
import sys
import time

import numpy as np
import pandas as pd
import torch
import wandb
from datasets import Dataset
from multiset import *
from sklearn.model_selection import train_test_split, StratifiedKFold
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback
)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

os.environ["WANDB_PROJECT"] = "GermEval2025-Substask2"
os.environ["WANDB_LOG_MODEL"] = "false"

experiment_name = 'exp027-1'

ALL_LABELS = ["affection declaration", "agreement", "ambiguous",
              "compliment", "encouragement", "gratitude", "group membership",
              "implicit", "positive feedback", "sympathy"]


def fine_grained_flausch_by_label(gold, predicted):
    gold['cid'] = gold['document'] + "_" + gold['comment_id'].apply(str)
    predicted['cid'] = predicted['document'] + "_" + predicted['comment_id'].apply(str)

    # annotation sets (predicted)
    pred_spans = Multiset()
    pred_spans_loose = Multiset()
    pred_types = Multiset()

    # annotation sets (gold)
    gold_spans = Multiset()
    gold_spans_loose = Multiset()
    gold_types = Multiset()

    for row in predicted.itertuples(index=False):
        pred_spans.add((row.cid, row.type, row.start, row.end))
        pred_spans_loose.add((row.cid, row.start, row.end))
        pred_types.add((row.cid, row.type))
    for row in gold.itertuples(index=False):
        gold_spans.add((row.cid, row.type, row.start, row.end))
        gold_spans_loose.add((row.cid, row.start, row.end))
        gold_types.add((row.cid, row.type))

    # precision = true_pos / (true_pos + false_pos)
    # recall = true_pos / (true_pos + false_neg)
    # f_1 = 2 * prec * rec / (prec + rec)

    results = {'TOTAL': {'STRICT': {}, 'SPANS': {}, 'TYPES': {}}}
    # label-wise evaluation (only for strict and type)
    for label in ALL_LABELS:
        results[label] = {'STRICT': {}, 'TYPES': {}}
        gold_spans_x = set(filter(lambda x: x[1].__eq__(label), gold_spans))
        pred_spans_x = set(filter(lambda x: x[1].__eq__(label), pred_spans))
        gold_types_x = set(filter(lambda x: x[1].__eq__(label), gold_types))
        pred_types_x = set(filter(lambda x: x[1].__eq__(label), pred_types))

        # strict: spans + type must match
        ### NOTE: x and y / x returns 0 if x = 0 and y / x otherwise (guards against zero division)
        strict_p = float(len(pred_spans_x)) and float(len(gold_spans_x.intersection(pred_spans_x))) / len(pred_spans_x)
        strict_r = float(len(gold_spans_x)) and float(len(gold_spans_x.intersection(pred_spans_x))) / len(gold_spans_x)
        strict_f = (strict_p + strict_r) and 2 * strict_p * strict_r / (strict_p + strict_r)
        results[label]['STRICT']['prec'] = strict_p
        results[label]['STRICT']['rec'] = strict_r
        results[label]['STRICT']['f1'] = strict_f

        # detection mode: only types must match (per post)
        types_p = float(len(pred_types_x)) and float(len(gold_types_x.intersection(pred_types_x))) / len(pred_types_x)
        types_r = float(len(gold_types_x)) and float(len(gold_types_x.intersection(pred_types_x))) / len(gold_types_x)
        types_f = (types_p + types_r) and 2 * types_p * types_r / (types_p + types_r)
        results[label]['TYPES']['prec'] = types_p
        results[label]['TYPES']['rec'] = types_r
        results[label]['TYPES']['f1'] = types_f

    # Overall evaluation
    # strict: spans + type must match
    strict_p = float(len(pred_spans)) and float(len(gold_spans.intersection(pred_spans))) / len(pred_spans)
    strict_r = float(len(gold_spans)) and float(len(gold_spans.intersection(pred_spans))) / len(gold_spans)
    strict_f = (strict_p + strict_r) and 2 * strict_p * strict_r / (strict_p + strict_r)
    results['TOTAL']['STRICT']['prec'] = strict_p
    results['TOTAL']['STRICT']['rec'] = strict_r
    results['TOTAL']['STRICT']['f1'] = strict_f

    # spans: spans must match
    spans_p = float(len(pred_spans_loose)) and float(len(gold_spans_loose.intersection(pred_spans_loose))) / len(pred_spans_loose)
    spans_r = float(len(gold_spans_loose)) and float(len(gold_spans_loose.intersection(pred_spans_loose))) / len(gold_spans_loose)
    spans_f = (spans_p + spans_r) and 2 * spans_p * spans_r / (spans_p + spans_r)
    results['TOTAL']['SPANS']['prec'] = spans_p
    results['TOTAL']['SPANS']['rec'] = spans_r
    results['TOTAL']['SPANS']['f1'] = spans_f

    # detection mode: only types must match (per post)
    types_p = float(len(pred_types)) and float(len(gold_types.intersection(pred_types))) / len(pred_types)
    types_r = float(len(gold_types)) and float(len(gold_types.intersection(pred_types))) / len(gold_types)
    types_f = (types_p + types_r) and 2 * types_p * types_r / (types_p + types_r)
    results['TOTAL']['TYPES']['prec'] = types_p
    results['TOTAL']['TYPES']['rec'] = types_r
    results['TOTAL']['TYPES']['f1'] = types_f

    return results


class SpanClassifierWithStrictF1:
    def __init__(self, model_name="deepset/gbert-base"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.labels = [
            "O",
            "B-positive feedback", "B-compliment", "B-affection declaration", "B-encouragement", "B-gratitude", "B-agreement", "B-ambiguous", "B-implicit", "B-group membership", "B-sympathy",
            "I-positive feedback", "I-compliment", "I-affection declaration", "I-encouragement", "I-gratitude", "I-agreement", "I-ambiguous", "I-implicit", "I-group membership", "I-sympathy"
        ]
        self.label2id = {label: i for i, label in enumerate(self.labels)}
        self.id2label = {i: label for i, label in enumerate(self.labels)}

    def create_dataset(self, comments_df, spans_df):
        """Create the dataset with BIO labels and keep the evaluation data."""
        examples = []
        eval_data = []  # for the strict F1 computation

        spans_grouped = spans_df.groupby(['document', 'comment_id'])

        for _, row in comments_df.iterrows():
            text = row['comment']
            document = row['document']
            comment_id = row['comment_id']
            key = (document, comment_id)

            # true spans for this comment
            if key in spans_grouped.groups:
                true_spans = [(span_type, int(start), int(end))
                              for span_type, start, end in
                              spans_grouped.get_group(key)[['type', 'start', 'end']].values]
            else:
                true_spans = []

            # tokenization
            tokenized = self.tokenizer(text, truncation=True, max_length=512,
                                       return_offsets_mapping=True)

            # create BIO labels
            labels = self._create_bio_labels(tokenized['offset_mapping'],
                                             spans_grouped.get_group(key)[['start', 'end', 'type']].values
                                             if key in spans_grouped.groups else [])

            examples.append({
                'input_ids': tokenized['input_ids'],
                'attention_mask': tokenized['attention_mask'],
                'labels': labels
            })

            # keep the evaluation data
            eval_data.append({
                'text': text,
                'offset_mapping': tokenized['offset_mapping'],
                'true_spans': true_spans,
                'document': document,
                'comment_id': comment_id
            })

        return examples, eval_data

    def _create_bio_labels(self, offset_mapping, spans):
        """Create BIO labels for the tokens."""
        labels = [0] * len(offset_mapping)  # 0 = "O"

        for start, end, type_label in spans:
            for i, (token_start, token_end) in enumerate(offset_mapping):
                if token_start is None:  # special tokens
                    continue

                # token overlaps with the span
                if token_start < end and token_end > start:
                    if token_start <= start:
                        if labels[i] != 0:
                            # don't overwrite labels if spans are overlapping; just skip the span
                            break
                        labels[i] = self.label2id[f'B-{type_label}']  # e.g. B-compliment
                    else:
                        labels[i] = self.label2id[f'I-{type_label}']  # e.g. I-compliment

        return labels

    def _predictions_to_dataframe(self, predictions_list, comments_df_subset):
        """Convert predictions to a DataFrame for the Flausch metric."""
        pred_data = []

        for i, pred in enumerate(predictions_list):
            if i < len(comments_df_subset):
                row = comments_df_subset.iloc[i]
                document = row['document']
                comment_id = row['comment_id']

                for span in pred['spans']:
                    pred_data.append({
                        'document': document,
                        'comment_id': comment_id,
                        'type': span['type'],
                        'start': span['start'],
                        'end': span['end']
                    })

        return pd.DataFrame(pred_data)

    # --- helper that builds a DataFrame of spans from eval data + predictions ---
    def _build_span_dfs(self, eval_data, batch_pred_spans):
        """
        eval_data: list of dicts with keys document, comment_id, true_spans
        batch_pred_spans: list of lists of (type, start, end)
        returns (gold_df, pred_df) suitable for fine_grained_flausch_by_label
        """
        rows_gold = []
        rows_pred = []
        for item, pred_spans in zip(eval_data, batch_pred_spans):
            doc = item['document']
            cid = item['comment_id']
            # gold
            for t, s, e in item['true_spans']:
                rows_gold.append({
                    'document': doc,
                    'comment_id': cid,
                    'type': t,
                    'start': s,
                    'end': e
                })
            # pred
            for t, s, e in pred_spans:
                rows_pred.append({
                    'document': doc,
                    'comment_id': cid,
                    'type': t,
                    'start': s,
                    'end': e
                })
        gold_df = pd.DataFrame(rows_gold, columns=['document', 'comment_id', 'type', 'start', 'end'])
        pred_df = pd.DataFrame(rows_pred, columns=['document', 'comment_id', 'type', 'start', 'end'])
        return gold_df, pred_df

    def compute_metrics(self, eval_pred):
        """
        Called by the HF Trainer at each evaluation step.
        We collect batch predictions, reconstruct gold/pred spans,
        call fine_grained_flausch_by_label and return the TOTAL/STRICT metrics.
        """
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=2)

        # reconstruct spans per example in this batch
        batch_pred_spans = []
        for i, (p_seq, lab_seq) in enumerate(zip(preds, labels)):
            # skip padding (-100)
            valid_preds = []
            valid_offsets = []
            offsets = self.current_eval_data[i]['offset_mapping']
            for j, (p, l) in enumerate(zip(p_seq, lab_seq)):
                if l != -100:
                    valid_preds.append(int(p))
                    valid_offsets.append(offsets[j])
            # convert to spans
            pred_spans = self._predictions_to_spans(valid_preds, valid_offsets,
                                                    self.current_eval_data[i]['text'])
            # to (type, start, end) tuples
            batch_pred_spans.append([(sp['type'], sp['start'], sp['end'])
                                     for sp in pred_spans])

        # build the gold/pred DataFrames
        gold_df, pred_df = self._build_span_dfs(self.current_eval_data,
                                                batch_pred_spans)

        # call the fine-grained metrics
        results = fine_grained_flausch_by_label(gold_df, pred_df)

        # extract the TOTAL/STRICT metrics
        total = results['TOTAL']['STRICT']
        return {
            'strict_prec': torch.tensor(total['prec'], dtype=torch.float32),
            'strict_rec': torch.tensor(total['rec'], dtype=torch.float32),
            'strict_f1': torch.tensor(total['f1'], dtype=torch.float32),
        }

    def evaluate_by_label(self, comments_df, spans_df):
        """
        Replaces evaluate_strict_f1. Runs a full pass over all comments,
        uses self.predict() to get spans, then calls fine_grained_flausch_by_label
        and prints & returns the TOTAL metrics.
        """
        # 1) run predictions
        texts = comments_df['comment'].tolist()
        docs = comments_df['document'].tolist()
        cids = comments_df['comment_id'].tolist()
        preds = self.predict(texts)

        # 2) build gold and pred lists
        gold_rows = []
        for (_, row) in comments_df.iterrows():
            key = (row['document'], row['comment_id'])
            # get all true spans for this comment_id
            group = spans_df[
                (spans_df.document == row['document']) &
                (spans_df.comment_id == row['comment_id'])
            ]
            for _, sp in group.iterrows():
                gold_rows.append({
                    'document': row['document'],
                    'comment_id': row['comment_id'],
                    'type': sp['type'],
                    'start': sp['start'],
                    'end': sp['end']
                })

        pred_rows = []
        for doc, cid, p in zip(docs, cids, preds):
            for sp in p['spans']:
                pred_rows.append({
                    'document': doc,
                    'comment_id': cid,
                    'type': sp['type'],
                    'start': sp['start'],
                    'end': sp['end']
                })

        gold_df = pd.DataFrame(gold_rows, columns=['document', 'comment_id', 'type', 'start', 'end'])
        pred_df = pd.DataFrame(pred_rows, columns=['document', 'comment_id', 'type', 'start', 'end'])

        # 3) call the fine-grained metrics
        results = fine_grained_flausch_by_label(gold_df, pred_df)

        # 4) extract and print
        total = results['TOTAL']
        print("\n=== EVALUATION BY FLAUSCH METRICS ===")
        for mode in ['STRICT', 'SPANS', 'TYPES']:
            m = total[mode]
            print(f"{mode:6}  P={m['prec']:.4f}  R={m['rec']:.4f}  F1={m['f1']:.4f}")

        return results

    def _predictions_to_spans(self, predicted_labels, offset_mapping, text):
        """Convert token predictions to character spans."""
        spans = []
        current_span = None

        for i, label_id in enumerate(predicted_labels):
            if i >= len(offset_mapping):
                break

            label = self.id2label[label_id]
            token_start, token_end = offset_mapping[i]

            if token_start is None:
                continue

            if label.startswith('B-'):
                if current_span:
                    spans.append(current_span)
                current_span = {
                    'type': label[2:],
                    'start': token_start,
                    'end': token_end,
                    'text': text[token_start:token_end]
                }
            elif label.startswith('I-') and current_span:
                current_span['end'] = token_end
                current_span['text'] = text[current_span['start']:current_span['end']]
            else:
                if current_span:
                    spans.append(current_span)
                current_span = None

        if current_span:
            spans.append(current_span)

        return spans

    def predict(self, texts):
        """Predict spans for new texts."""
        if not hasattr(self, 'model'):
            raise ValueError("The model has to be trained first!")

        predictions = []
        device = next(self.model.parameters()).device

        for text in texts:
            # tokenization
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True,
                                    max_length=512, return_offsets_mapping=True)

            offset_mapping = inputs.pop('offset_mapping')
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # prediction
            with torch.no_grad():
                outputs = self.model(**inputs)

            predicted_labels = torch.argmax(outputs.logits, dim=2)[0].cpu().numpy()

            # extract spans
            spans = self._predictions_to_spans(predicted_labels, offset_mapping[0], text)
            predictions.append({'text': text, 'spans': spans})

        return predictions

    def train(self, comments_df, spans_df, experiment_name):
        wandb.init(project=os.environ["WANDB_PROJECT"], name=f"{experiment_name}",
                   group=experiment_name)

        # build the dataset
        examples, eval_data = self.create_dataset(comments_df, spans_df)
        train_examples, val_examples = train_test_split(examples, test_size=0.1, random_state=42)

        # split the evaluation data accordingly
        train_indices, val_indices = train_test_split(range(len(examples)), test_size=0.1, random_state=42)
        self.current_eval_data = [eval_data[i] for i in val_indices]

        test_comments = comments_df.iloc[val_indices].reset_index(drop=True)

        train_dataset = Dataset.from_list(train_examples)
        val_dataset = Dataset.from_list(val_examples)

        # initialize a fresh model
        model = AutoModelForTokenClassification.from_pretrained(
            self.model_name,
            num_labels=len(self.labels),
            id2label=self.id2label,
            label2id=self.label2id
        )

        # training arguments
        fold_output_dir = f"{experiment_name}"
        training_args = TrainingArguments(
            output_dir=fold_output_dir,
            learning_rate=2e-5,
            warmup_steps=400,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=16,
            num_train_epochs=20,
            eval_strategy="steps",
            eval_steps=40,
            save_strategy="steps",
            save_steps=40,
            load_best_model_at_end=True,
            metric_for_best_model="strict_f1",
            greater_is_better=True,
            logging_steps=10,
            logging_strategy="steps",
            report_to="all",
            disable_tqdm=False,
            seed=42,
            save_total_limit=3,
        )

        # trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=DataCollatorForTokenClassification(self.tokenizer),
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=87)]
            # 87 steps = 3.0 epochs with 29 steps per epoch
        )

        # training
        print(f"Training on {len(train_dataset)} examples")
        print(f"Validating on {len(val_dataset)} examples")
        trainer.train()

        # keep the current model
        self.model = model

        # evaluate the model on the held-out data
        print(f"Evaluating on {len(test_comments)} test examples")
        metrics = self.evaluate_by_label(test_comments, spans_df)
        wandb.log({
            'strict_f1': metrics['TOTAL']['STRICT']['f1'],
            'strict_precision': metrics['TOTAL']['STRICT']['prec'],
            'strict_recall': metrics['TOTAL']['STRICT']['rec'],
            'spans_f1': metrics['TOTAL']['SPANS']['f1'],
            'types_f1': metrics['TOTAL']['TYPES']['f1']
        })

        # save the model
        torch.save(model.state_dict(), f'{fold_output_dir}_model.pth')

        torch.cuda.memory.empty_cache()
        wandb.finish()

        return trainer

    def cross_validate(self, comments_df, spans_df, n_splits=5, output_dir_prefix="span-classifier-cv"):
        """Run n-fold cross-validation with StratifiedKFold."""

        # create labels for stratification (based on a comment's first span type)
        strat_labels = []
        spans_grouped = spans_df.groupby(['document', 'comment_id'])
        for _, row in comments_df.iterrows():
            key = (row['document'], row['comment_id'])
            # first span type if the comment has spans, otherwise 0
            has_spans = spans_grouped.get_group(key).iloc[0]['type'] if key in spans_grouped.groups and len(spans_grouped.get_group(key)) > 0 else 0
            strat_labels.append(has_spans)

        # set up the StratifiedKFold splitter
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

        # collect the metrics for each fold
        fold_metrics = []

        # iterate over the folds
        for fold, (train_idx, test_idx) in enumerate(skf.split(range(len(comments_df)), strat_labels)):
            if '--fold' in sys.argv:
                fold_arg = int(sys.argv[sys.argv.index('--fold') + 1])
                if fold + 1 != fold_arg:
                    continue

            wandb.init(project=os.environ["WANDB_PROJECT"], name=f"{experiment_name}-fold-{fold+1}",
                       group=experiment_name)

            print(f"\n{'='*50}")
            print(f"Fold {fold+1}/{n_splits}")
            print(f"{'='*50}")

            # comments for this fold
            train_comments = comments_df.iloc[train_idx].reset_index(drop=True)
            test_comments = comments_df.iloc[test_idx].reset_index(drop=True)

            # rebuild the dataset for this fold
            examples, eval_data = self.create_dataset(train_comments, spans_df)
            train_examples, val_examples = train_test_split(examples, test_size=0.1, random_state=42)

            # split the evaluation data accordingly
            train_indices, val_indices = train_test_split(range(len(examples)), test_size=0.1, random_state=42)
            self.current_eval_data = [eval_data[i] for i in val_indices]

            train_dataset = Dataset.from_list(train_examples)
            val_dataset = Dataset.from_list(val_examples)

            # initialize a fresh model
            model = AutoModelForTokenClassification.from_pretrained(
                self.model_name,
                num_labels=len(self.labels),
                id2label=self.id2label,
                label2id=self.label2id
            )

            # training arguments
            fold_output_dir = f"{output_dir_prefix}-fold-{fold+1}"
            training_args = TrainingArguments(
                output_dir=fold_output_dir,
                learning_rate=2e-5,
                warmup_steps=400,
                per_device_train_batch_size=32,
                per_device_eval_batch_size=16,
                num_train_epochs=15,
                eval_strategy="steps",
                eval_steps=40,
                save_strategy="steps",
                save_steps=40,
                load_best_model_at_end=True,
                metric_for_best_model="strict_f1",
                greater_is_better=True,
                logging_steps=10,
                logging_strategy="steps",
                report_to="all",
                disable_tqdm=False,
                seed=42,
                save_total_limit=3,
            )

            # trainer
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                data_collator=DataCollatorForTokenClassification(self.tokenizer),
                compute_metrics=self.compute_metrics,
                callbacks=[EarlyStoppingCallback(early_stopping_patience=87)]  # 87 steps = 3.0 epochs with 29 steps per epoch
            )

            # training
            print(f"Training on {len(train_dataset)} examples")
            print(f"Validating on {len(val_dataset)} examples")
            trainer.train()

            # keep the current model
            self.model = model

            # evaluate the model on the test data
            print(f"Evaluating on {len(test_comments)} test examples")
            flausch_results = self.evaluate_by_label(test_comments, spans_df)

            # extract the main metrics for fold_metrics
            metrics = {
                'strict_f1': flausch_results['TOTAL']['STRICT']['f1'],
                'strict_precision': flausch_results['TOTAL']['STRICT']['prec'],
                'strict_recall': flausch_results['TOTAL']['STRICT']['rec'],
                'spans_f1': flausch_results['TOTAL']['SPANS']['f1'],
                'spans_precision': flausch_results['TOTAL']['SPANS']['prec'],
                'spans_recall': flausch_results['TOTAL']['SPANS']['rec'],
                'types_f1': flausch_results['TOTAL']['TYPES']['f1'],
                'types_precision': flausch_results['TOTAL']['TYPES']['prec'],
                'types_recall': flausch_results['TOTAL']['TYPES']['rec'],
                'full_results': flausch_results
            }

            fold_metrics.append(metrics)
            wandb.log(metrics, step=fold + 1)

            # save the model
            torch.save(model.state_dict(), f'{fold_output_dir}_model.pth')

            test_predictions = self.predict(test_comments['comment'].tolist())

            # save the per-fold data and metrics
            with open(f"test_results.{experiment_name}.fold-{fold+1}.pkl", "wb") as p:
                pickle.dump((train_comments, test_comments, test_predictions, train_examples, val_examples), p)

            with open(f"scores.{experiment_name}.txt", 'a') as f:
                f.write(f'[{time.strftime("%Y-%m-%d %H:%M:%S")}] Fold {fold+1} results:\n')
                f.write(f"[{experiment_name} fold-{fold+1}] {metrics}\n")

            torch.cuda.memory.empty_cache()
            wandb.finish()

        # print the summary
        print("\n" + "="*50)
        print("Cross-validation finished")
        print("="*50)

        # compute the average metrics
        avg_f1 = np.mean([m['strict_f1'] for m in fold_metrics])
        avg_precision = np.mean([m['strict_precision'] for m in fold_metrics])
        avg_recall = np.mean([m['strict_recall'] for m in fold_metrics])

        print(f"\nAverage metrics over {n_splits} folds:")
        print(f"Precision: {avg_precision:.10f}")
        print(f"Recall: {avg_recall:.10f}")
        print(f"F1 score: {avg_f1:.10f}")

        # standard deviation
        std_f1 = np.std([m['strict_f1'] for m in fold_metrics])
        std_precision = np.std([m['strict_precision'] for m in fold_metrics])
        std_recall = np.std([m['strict_recall'] for m in fold_metrics])

        print(f"\nStandard deviation over {n_splits} folds:")
        print(f"Precision: {std_precision:.10f}")
        print(f"Recall: {std_recall:.10f}")
        print(f"F1 score: {std_f1:.10f}")

        # print the results for each fold
        for fold, metrics in enumerate(fold_metrics):
            print(f"\nFold {fold+1} results:")
            print(f"Precision: {metrics['strict_precision']:.4f}")
            print(f"Recall: {metrics['strict_recall']:.4f}")
            print(f"F1 score: {metrics['strict_f1']:.4f}")

        return {
            'fold_metrics': fold_metrics,
            'avg_metrics': {
                'strict_f1': avg_f1,
                'strict_precision': avg_precision,
                'strict_recall': avg_recall
            },
            'std_metrics': {
                'strict_f1': std_f1,
                'strict_precision': std_precision,
                'strict_recall': std_recall
            }
        }


# load the data
comments: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/training data/comments.csv")
task1: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/training data/task1.csv")
task2: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/training data/task2.csv")
comments = comments.merge(task1, on=["document", "comment_id"])

test_data: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/test data/comments.csv")

# select a subset of the data for the experiment (e.g. 17000 comments); here the full set is used
experiment_data = comments

# classifier with strict F1
classifier = SpanClassifierWithStrictF1('deepset/gbert-large')

# run 5-fold cross-validation
cv_results = classifier.cross_validate(
    experiment_data,
    task2,
    n_splits=5,
    output_dir_prefix=experiment_name
)

# write results to text file
with open(f"scores.{experiment_name}.txt", 'a') as f:
    f.write(f'[{time.strftime("%Y-%m-%d %H:%M:%S")}] KFold cross validation of {experiment_name}\n')
    f.write(f'{cv_results}\n')

# optional: train a final model on all data
trainer = classifier.train(experiment_data, task2, f'{experiment_name}-final')
torch.save(classifier.model.state_dict(), f'{experiment_name}_final_model.pth')

# test prediction with the final model
test_texts = ["Das ist ein toller Kommentar!", "Schlechter Text hier.",
              "Sehr gutes Video. Danke! Ich finde Dich echt toll!", "Du bist doof!", "Das Licht ist echt gut.",
              "Team Einhorn", "Macht unbedingt weiter so!", "Das sehe ich ganz genauso.", "Stimmt, Du hast vollkommen Recht!",
              "Ich bin so dankbar ein #Lochinator zu sein"]

predictions = classifier.predict(test_texts)

for pred in predictions:
    print(f"\nText: {pred['text']}")
    for span in pred['spans']:
        print(f"  Span: '{span['text']}' ({span['start']}-{span['end']}) - {span['type']}")
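A quick way to sanity-check fine_grained_flausch_by_label is to run it on a hand-made gold/prediction pair. The following is a minimal sketch, not part of the commit: it assumes the function above (and the multiset package it depends on) is already in scope, and the document and comment IDs are invented for illustration.

import pandas as pd

# toy gold annotations: a single compliment span in one comment
gold = pd.DataFrame([
    {'document': 'doc1', 'comment_id': 1, 'type': 'compliment', 'start': 0, 'end': 9},
])

# toy predictions: the compliment is matched exactly, plus one spurious span
pred = pd.DataFrame([
    {'document': 'doc1', 'comment_id': 1, 'type': 'compliment', 'start': 0, 'end': 9},
    {'document': 'doc1', 'comment_id': 1, 'type': 'gratitude', 'start': 10, 'end': 16},
])

results = fine_grained_flausch_by_label(gold, pred)
print(results['TOTAL']['STRICT'])     # {'prec': 0.5, 'rec': 1.0, 'f1': 0.666...}
print(results['sympathy']['STRICT'])  # all 0.0: no gold or predicted sympathy spans

The 'x and y / x' idiom flagged in the NOTE comment is what keeps the second print at 0.0 instead of raising a ZeroDivisionError for labels that never occur.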
subtask_2/exp027-2.py
ADDED
@@ -0,0 +1,736 @@
import os
import pickle
import sys
import time

import numpy as np
import pandas as pd
import torch
import wandb
from datasets import Dataset
from multiset import *
from sklearn.model_selection import train_test_split, StratifiedKFold
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback
)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

os.environ["WANDB_PROJECT"] = "GermEval2025-Substask2"
os.environ["WANDB_LOG_MODEL"] = "false"

experiment_name = 'exp027-2'

ALL_LABELS = ["affection declaration", "agreement", "ambiguous",
              "compliment", "encouragement", "gratitude", "group membership",
              "implicit", "positive feedback", "sympathy"]


def fine_grained_flausch_by_label(gold, predicted):
    gold['cid'] = gold['document'] + "_" + gold['comment_id'].apply(str)
    predicted['cid'] = predicted['document'] + "_" + predicted['comment_id'].apply(str)

    # annotation sets (predicted)
    pred_spans = Multiset()
    pred_spans_loose = Multiset()
    pred_types = Multiset()

    # annotation sets (gold)
    gold_spans = Multiset()
    gold_spans_loose = Multiset()
    gold_types = Multiset()

    for row in predicted.itertuples(index=False):
        pred_spans.add((row.cid, row.type, row.start, row.end))
        pred_spans_loose.add((row.cid, row.start, row.end))
        pred_types.add((row.cid, row.type))
    for row in gold.itertuples(index=False):
        gold_spans.add((row.cid, row.type, row.start, row.end))
        gold_spans_loose.add((row.cid, row.start, row.end))
        gold_types.add((row.cid, row.type))

    # precision = true_pos / (true_pos + false_pos)
    # recall = true_pos / (true_pos + false_neg)
    # f_1 = 2 * prec * rec / (prec + rec)

    results = {'TOTAL': {'STRICT': {}, 'SPANS': {}, 'TYPES': {}}}
    # label-wise evaluation (only for strict and type)
    for label in ALL_LABELS:
        results[label] = {'STRICT': {}, 'TYPES': {}}
        gold_spans_x = set(filter(lambda x: x[1].__eq__(label), gold_spans))
        pred_spans_x = set(filter(lambda x: x[1].__eq__(label), pred_spans))
        gold_types_x = set(filter(lambda x: x[1].__eq__(label), gold_types))
        pred_types_x = set(filter(lambda x: x[1].__eq__(label), pred_types))

        # strict: spans + type must match
        ### NOTE: x and y / x returns 0 if x = 0 and y / x otherwise (guards against zero division)
        strict_p = float(len(pred_spans_x)) and float(len(gold_spans_x.intersection(pred_spans_x))) / len(pred_spans_x)
        strict_r = float(len(gold_spans_x)) and float(len(gold_spans_x.intersection(pred_spans_x))) / len(gold_spans_x)
        strict_f = (strict_p + strict_r) and 2 * strict_p * strict_r / (strict_p + strict_r)
        results[label]['STRICT']['prec'] = strict_p
        results[label]['STRICT']['rec'] = strict_r
        results[label]['STRICT']['f1'] = strict_f

        # detection mode: only types must match (per post)
        types_p = float(len(pred_types_x)) and float(len(gold_types_x.intersection(pred_types_x))) / len(pred_types_x)
        types_r = float(len(gold_types_x)) and float(len(gold_types_x.intersection(pred_types_x))) / len(gold_types_x)
        types_f = (types_p + types_r) and 2 * types_p * types_r / (types_p + types_r)
        results[label]['TYPES']['prec'] = types_p
        results[label]['TYPES']['rec'] = types_r
        results[label]['TYPES']['f1'] = types_f

    # Overall evaluation
    # strict: spans + type must match
    strict_p = float(len(pred_spans)) and float(len(gold_spans.intersection(pred_spans))) / len(pred_spans)
    strict_r = float(len(gold_spans)) and float(len(gold_spans.intersection(pred_spans))) / len(gold_spans)
    strict_f = (strict_p + strict_r) and 2 * strict_p * strict_r / (strict_p + strict_r)
    results['TOTAL']['STRICT']['prec'] = strict_p
    results['TOTAL']['STRICT']['rec'] = strict_r
    results['TOTAL']['STRICT']['f1'] = strict_f

    # spans: spans must match
    spans_p = float(len(pred_spans_loose)) and float(len(gold_spans_loose.intersection(pred_spans_loose))) / len(pred_spans_loose)
    spans_r = float(len(gold_spans_loose)) and float(len(gold_spans_loose.intersection(pred_spans_loose))) / len(gold_spans_loose)
    spans_f = (spans_p + spans_r) and 2 * spans_p * spans_r / (spans_p + spans_r)
    results['TOTAL']['SPANS']['prec'] = spans_p
    results['TOTAL']['SPANS']['rec'] = spans_r
    results['TOTAL']['SPANS']['f1'] = spans_f

    # detection mode: only types must match (per post)
    types_p = float(len(pred_types)) and float(len(gold_types.intersection(pred_types))) / len(pred_types)
    types_r = float(len(gold_types)) and float(len(gold_types.intersection(pred_types))) / len(gold_types)
    types_f = (types_p + types_r) and 2 * types_p * types_r / (types_p + types_r)
    results['TOTAL']['TYPES']['prec'] = types_p
    results['TOTAL']['TYPES']['rec'] = types_r
    results['TOTAL']['TYPES']['f1'] = types_f

    return results


class SpanClassifierWithStrictF1:
    def __init__(self, model_name="deepset/gbert-base"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)

        self.labels = [
            "O",
            "B-positive feedback", "B-compliment", "B-affection declaration", "B-encouragement", "B-gratitude", "B-agreement", "B-ambiguous", "B-implicit", "B-group membership", "B-sympathy",
            "I-positive feedback", "I-compliment", "I-affection declaration", "I-encouragement", "I-gratitude", "I-agreement", "I-ambiguous", "I-implicit", "I-group membership", "I-sympathy"
        ]
        self.label2id = {label: i for i, label in enumerate(self.labels)}
        self.id2label = {i: label for i, label in enumerate(self.labels)}

    def create_dataset(self, comments_df, spans_df):
        """Create the dataset with BIO labels and keep the evaluation data."""
        examples = []
        eval_data = []  # for the strict F1 computation

        spans_grouped = spans_df.groupby(['document', 'comment_id'])

        for _, row in comments_df.iterrows():
            text = row['comment']
            document = row['document']
            comment_id = row['comment_id']
            key = (document, comment_id)

            # true spans for this comment
            if key in spans_grouped.groups:
                true_spans = [(span_type, int(start), int(end))
                              for span_type, start, end in
                              spans_grouped.get_group(key)[['type', 'start', 'end']].values]
            else:
                true_spans = []

            # tokenization
            tokenized = self.tokenizer(text, truncation=True, max_length=512,
                                       return_offsets_mapping=True)

            # create BIO labels
            labels = self._create_bio_labels(tokenized['offset_mapping'],
                                             spans_grouped.get_group(key)[['start', 'end', 'type']].values
                                             if key in spans_grouped.groups else [])

            examples.append({
                'input_ids': tokenized['input_ids'],
                'attention_mask': tokenized['attention_mask'],
                'labels': labels
            })

            # keep the evaluation data
            eval_data.append({
                'text': text,
                'offset_mapping': tokenized['offset_mapping'],
                'true_spans': true_spans,
                'document': document,
                'comment_id': comment_id
            })

        return examples, eval_data

    def _create_bio_labels(self, offset_mapping, spans):
        """Create BIO labels for the tokens."""
        labels = [0] * len(offset_mapping)  # 0 = "O"

        for start, end, type_label in spans:
            for i, (token_start, token_end) in enumerate(offset_mapping):
                if token_start is None:  # special tokens
                    continue

                # token overlaps with the span
                if token_start < end and token_end > start:
                    if token_start <= start:
                        if labels[i] != 0:
                            # don't overwrite labels if spans are overlapping; just skip the span
                            break
                        labels[i] = self.label2id[f'B-{type_label}']  # e.g. B-compliment
                    else:
                        labels[i] = self.label2id[f'I-{type_label}']  # e.g. I-compliment

        return labels

    def _predictions_to_dataframe(self, predictions_list, comments_df_subset):
        """Convert predictions to a DataFrame for the Flausch metric."""
        pred_data = []

        for i, pred in enumerate(predictions_list):
            if i < len(comments_df_subset):
                row = comments_df_subset.iloc[i]
                document = row['document']
                comment_id = row['comment_id']

                for span in pred['spans']:
                    pred_data.append({
                        'document': document,
                        'comment_id': comment_id,
                        'type': span['type'],
                        'start': span['start'],
                        'end': span['end']
                    })

        return pd.DataFrame(pred_data)

    # --- helper that builds a DataFrame of spans from eval data + predictions ---
    def _build_span_dfs(self, eval_data, batch_pred_spans):
        """
        eval_data: list of dicts with keys document, comment_id, true_spans
        batch_pred_spans: list of lists of (type, start, end)
        returns (gold_df, pred_df) suitable for fine_grained_flausch_by_label
        """
        rows_gold = []
        rows_pred = []
        for item, pred_spans in zip(eval_data, batch_pred_spans):
            doc = item['document']
            cid = item['comment_id']
            # gold
            for t, s, e in item['true_spans']:
                rows_gold.append({
                    'document': doc,
                    'comment_id': cid,
                    'type': t,
                    'start': s,
                    'end': e
                })
            # pred
            for t, s, e in pred_spans:
                rows_pred.append({
                    'document': doc,
                    'comment_id': cid,
                    'type': t,
                    'start': s,
                    'end': e
                })
        gold_df = pd.DataFrame(rows_gold, columns=['document', 'comment_id', 'type', 'start', 'end'])
        pred_df = pd.DataFrame(rows_pred, columns=['document', 'comment_id', 'type', 'start', 'end'])
        return gold_df, pred_df

    def compute_metrics(self, eval_pred):
        """
        Called by the HF Trainer at each evaluation step.
        We collect batch predictions, reconstruct gold/pred spans,
        call fine_grained_flausch_by_label and return the TOTAL/STRICT metrics.
        """
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=2)

        # reconstruct spans per example in this batch
        batch_pred_spans = []
        for i, (p_seq, lab_seq) in enumerate(zip(preds, labels)):
            # skip padding (-100)
            valid_preds = []
            valid_offsets = []
            offsets = self.current_eval_data[i]['offset_mapping']
            for j, (p, l) in enumerate(zip(p_seq, lab_seq)):
                if l != -100:
                    valid_preds.append(int(p))
                    valid_offsets.append(offsets[j])
            # convert to spans
            pred_spans = self._predictions_to_spans(valid_preds, valid_offsets,
                                                    self.current_eval_data[i]['text'])
            # to (type, start, end) tuples
            batch_pred_spans.append([(sp['type'], sp['start'], sp['end'])
                                     for sp in pred_spans])

        # build the gold/pred DataFrames
        gold_df, pred_df = self._build_span_dfs(self.current_eval_data,
                                                batch_pred_spans)

        # call the fine-grained metrics
        results = fine_grained_flausch_by_label(gold_df, pred_df)

        # extract the TOTAL/STRICT metrics
        total = results['TOTAL']['STRICT']
        return {
            'strict_prec': torch.tensor(total['prec'], dtype=torch.float32),
            'strict_rec': torch.tensor(total['rec'], dtype=torch.float32),
            'strict_f1': torch.tensor(total['f1'], dtype=torch.float32),
        }

    def evaluate_by_label(self, comments_df, spans_df):
        """
        Replaces evaluate_strict_f1. Runs a full pass over all comments,
        uses self.predict() to get spans, then calls fine_grained_flausch_by_label
        and prints & returns the TOTAL metrics.
        """
        # 1) run predictions
        texts = comments_df['comment'].tolist()
        docs = comments_df['document'].tolist()
        cids = comments_df['comment_id'].tolist()
        preds = self.predict(texts)

        # 2) build gold and pred lists
        gold_rows = []
        for (_, row) in comments_df.iterrows():
            key = (row['document'], row['comment_id'])
            # get all true spans for this comment_id
            group = spans_df[
                (spans_df.document == row['document']) &
                (spans_df.comment_id == row['comment_id'])
            ]
            for _, sp in group.iterrows():
                gold_rows.append({
                    'document': row['document'],
                    'comment_id': row['comment_id'],
                    'type': sp['type'],
                    'start': sp['start'],
                    'end': sp['end']
                })

        pred_rows = []
        for doc, cid, p in zip(docs, cids, preds):
            for sp in p['spans']:
                pred_rows.append({
                    'document': doc,
                    'comment_id': cid,
                    'type': sp['type'],
                    'start': sp['start'],
                    'end': sp['end']
                })

        gold_df = pd.DataFrame(gold_rows, columns=['document', 'comment_id', 'type', 'start', 'end'])
        pred_df = pd.DataFrame(pred_rows, columns=['document', 'comment_id', 'type', 'start', 'end'])

        # 3) call the fine-grained metrics
        results = fine_grained_flausch_by_label(gold_df, pred_df)

        # 4) extract and print
        total = results['TOTAL']
        print("\n=== EVALUATION BY FLAUSCH METRICS ===")
        for mode in ['STRICT', 'SPANS', 'TYPES']:
            m = total[mode]
            print(f"{mode:6}  P={m['prec']:.4f}  R={m['rec']:.4f}  F1={m['f1']:.4f}")

        return results

    def _predictions_to_spans(self, predicted_labels, offset_mapping, text):
        """Convert token predictions to character spans."""
        spans = []
        current_span = None

        for i, label_id in enumerate(predicted_labels):
            if i >= len(offset_mapping):
                break

            label = self.id2label[label_id]
            token_start, token_end = offset_mapping[i]

            if token_start is None:
                continue

            if label.startswith('B-'):
                if current_span:
                    spans.append(current_span)
                current_span = {
                    'type': label[2:],
                    'start': token_start,
                    'end': token_end,
                    'text': text[token_start:token_end]
                }
            elif label.startswith('I-') and current_span:
                current_span['end'] = token_end
                current_span['text'] = text[current_span['start']:current_span['end']]
            else:
                if current_span:
                    spans.append(current_span)
                current_span = None

        if current_span:
            spans.append(current_span)

        return spans

    def predict(self, texts):
        """Predict spans for new texts."""
        if not hasattr(self, 'model'):
            raise ValueError("The model has to be trained first!")

        predictions = []
        device = next(self.model.parameters()).device

        for text in texts:
            # tokenization
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True,
                                    max_length=512, return_offsets_mapping=True)

            offset_mapping = inputs.pop('offset_mapping')
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # prediction
            with torch.no_grad():
                outputs = self.model(**inputs)

            predicted_labels = torch.argmax(outputs.logits, dim=2)[0].cpu().numpy()

            # extract spans
            spans = self._predictions_to_spans(predicted_labels, offset_mapping[0], text)
            predictions.append({'text': text, 'spans': spans})

        return predictions

    def train(self, comments_df, spans_df, experiment_name):
        wandb.init(project=os.environ["WANDB_PROJECT"], name=f"{experiment_name}",
                   group=experiment_name)

        # build the dataset
        examples, eval_data = self.create_dataset(comments_df, spans_df)
        train_examples, val_examples = train_test_split(examples, test_size=0.1, random_state=42)

        # split the evaluation data accordingly
        train_indices, val_indices = train_test_split(range(len(examples)), test_size=0.1, random_state=42)
        self.current_eval_data = [eval_data[i] for i in val_indices]

        test_comments = comments_df.iloc[val_indices].reset_index(drop=True)

        train_dataset = Dataset.from_list(train_examples)
        val_dataset = Dataset.from_list(val_examples)

        # initialize a fresh model
        model = AutoModelForTokenClassification.from_pretrained(
            self.model_name,
            num_labels=len(self.labels),
            id2label=self.id2label,
            label2id=self.label2id
        )

        # training arguments
        fold_output_dir = f"{experiment_name}"
        training_args = TrainingArguments(
            output_dir=fold_output_dir,
            learning_rate=2e-5,
            warmup_steps=500,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            num_train_epochs=20,
            eval_strategy="steps",
            eval_steps=40,
            save_strategy="steps",
            save_steps=40,
            load_best_model_at_end=True,
            metric_for_best_model="strict_f1",
            greater_is_better=True,
            logging_steps=10,
            logging_strategy="steps",
            report_to="all",
            disable_tqdm=False,
            seed=42,
            save_total_limit=3,
        )

        # trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=DataCollatorForTokenClassification(self.tokenizer),
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=87)]
            # 87 steps = 3.0 epochs with 29 steps per epoch
        )

        # training
        print(f"Training on {len(train_dataset)} examples")
        print(f"Validating on {len(val_dataset)} examples")
        trainer.train()

        # keep the current model
        self.model = model

        # evaluate the model on the held-out data
        print(f"Evaluating on {len(test_comments)} test examples")
        metrics = self.evaluate_by_label(test_comments, spans_df)
        wandb.log({
            'strict_f1': metrics['TOTAL']['STRICT']['f1'],
            'strict_precision': metrics['TOTAL']['STRICT']['prec'],
            'strict_recall': metrics['TOTAL']['STRICT']['rec'],
            'spans_f1': metrics['TOTAL']['SPANS']['f1'],
            'types_f1': metrics['TOTAL']['TYPES']['f1']
        })

        # save the model
        torch.save(model.state_dict(), f'{fold_output_dir}_model.pth')

        torch.cuda.memory.empty_cache()
        wandb.finish()

        return trainer

    def cross_validate(self, comments_df, spans_df, n_splits=5, output_dir_prefix="span-classifier-cv"):
        """Run n-fold cross-validation with StratifiedKFold."""

        # create labels for stratification (based on a comment's first span type)
        strat_labels = []
        spans_grouped = spans_df.groupby(['document', 'comment_id'])
        for _, row in comments_df.iterrows():
            key = (row['document'], row['comment_id'])
            # first span type if the comment has spans, otherwise 0
            has_spans = spans_grouped.get_group(key).iloc[0]['type'] if key in spans_grouped.groups and len(spans_grouped.get_group(key)) > 0 else 0
            strat_labels.append(has_spans)

        # set up the StratifiedKFold splitter
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

        # collect the metrics for each fold
        fold_metrics = []

        # iterate over the folds
        for fold, (train_idx, test_idx) in enumerate(skf.split(range(len(comments_df)), strat_labels)):
            if '--fold' in sys.argv:
                fold_arg = int(sys.argv[sys.argv.index('--fold') + 1])
                if fold + 1 != fold_arg:
                    continue

            wandb.init(project=os.environ["WANDB_PROJECT"], name=f"{experiment_name}-fold-{fold+1}",
                       group=experiment_name)

            print(f"\n{'='*50}")
            print(f"Fold {fold+1}/{n_splits}")
            print(f"{'='*50}")

            # comments for this fold
            train_comments = comments_df.iloc[train_idx].reset_index(drop=True)
            test_comments = comments_df.iloc[test_idx].reset_index(drop=True)

            # rebuild the dataset for this fold
            examples, eval_data = self.create_dataset(train_comments, spans_df)
            train_examples, val_examples = train_test_split(examples, test_size=0.1, random_state=42)

            # split the evaluation data accordingly
            train_indices, val_indices = train_test_split(range(len(examples)), test_size=0.1, random_state=42)
            self.current_eval_data = [eval_data[i] for i in val_indices]

            train_dataset = Dataset.from_list(train_examples)
            val_dataset = Dataset.from_list(val_examples)

            # initialize a fresh model
            model = AutoModelForTokenClassification.from_pretrained(
                self.model_name,
                num_labels=len(self.labels),
                id2label=self.id2label,
                label2id=self.label2id
            )

            # training arguments
            fold_output_dir = f"{output_dir_prefix}-fold-{fold+1}"
            training_args = TrainingArguments(
                output_dir=fold_output_dir,
                learning_rate=2e-5,
                warmup_steps=500,
                per_device_train_batch_size=32,
                per_device_eval_batch_size=32,
                num_train_epochs=15,
                eval_strategy="steps",
                eval_steps=40,
                save_strategy="steps",
                save_steps=40,
                load_best_model_at_end=True,
                metric_for_best_model="strict_f1",
                greater_is_better=True,
                logging_steps=10,
                logging_strategy="steps",
                report_to="all",
                disable_tqdm=False,
                seed=42,
                save_total_limit=3,
            )

            # trainer
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
|
| 591 |
+
eval_dataset=val_dataset,
|
| 592 |
+
data_collator=DataCollatorForTokenClassification(self.tokenizer),
|
| 593 |
+
compute_metrics=self.compute_metrics,
|
| 594 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=87)] # 87 steps = 3.0 epochs with 29 steps per epoch
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
# Training
|
| 598 |
+
print(f"Training auf {len(train_dataset)} Beispielen")
|
| 599 |
+
print(f"Validation auf {len(val_dataset)} Beispielen")
|
| 600 |
+
trainer.train()
|
| 601 |
+
|
| 602 |
+
# Aktuelles Modell speichern
|
| 603 |
+
self.model = model
|
| 604 |
+
|
| 605 |
+
# Modell evaluieren auf Test-Daten
|
| 606 |
+
print(f"Evaluierung auf {len(test_comments)} Test-Beispielen")
|
| 607 |
+
flausch_results = self.evaluate_by_label(test_comments, spans_df)
|
| 608 |
+
|
| 609 |
+
# Extrahiere Hauptmetriken für fold_metrics
|
| 610 |
+
metrics = {
|
| 611 |
+
'strict_f1': flausch_results['TOTAL']['STRICT']['f1'],
|
| 612 |
+
'strict_precision': flausch_results['TOTAL']['STRICT']['prec'],
|
| 613 |
+
'strict_recall': flausch_results['TOTAL']['STRICT']['rec'],
|
| 614 |
+
'spans_f1': flausch_results['TOTAL']['SPANS']['f1'],
|
| 615 |
+
'spans_precision': flausch_results['TOTAL']['SPANS']['prec'],
|
| 616 |
+
'spans_recall': flausch_results['TOTAL']['SPANS']['rec'],
|
| 617 |
+
'types_f1': flausch_results['TOTAL']['TYPES']['f1'],
|
| 618 |
+
'types_precision': flausch_results['TOTAL']['TYPES']['prec'],
|
| 619 |
+
'types_recall': flausch_results['TOTAL']['TYPES']['rec'],
|
| 620 |
+
'full_results': flausch_results
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
fold_metrics.append(metrics)
|
| 624 |
+
wandb.log(metrics, step=fold + 1)
|
| 625 |
+
|
| 626 |
+
# Speichere Modell
|
| 627 |
+
torch.save(model.state_dict(), f'{fold_output_dir}_model.pth')
|
| 628 |
+
|
| 629 |
+
test_predictions = self.predict(test_comments['comment'].tolist())
|
| 630 |
+
|
| 631 |
+
# Speichere Metriken
|
| 632 |
+
with open(f"test_results.{experiment_name}.fold-{fold+1}.pkl", "wb") as p:
|
| 633 |
+
pickle.dump((train_comments, test_comments, test_predictions, train_examples, val_examples), p)
|
| 634 |
+
|
| 635 |
+
with open(f"scores.{experiment_name}.txt", 'a') as f:
|
| 636 |
+
f.write(f'[{time.strftime("%Y-%m-%d %H:%M:%S")}] Fold {fold+1} Ergebnisse:\n')
|
| 637 |
+
f.write(f"[{experiment_name} fold-{fold+1} {metrics}\n")
|
| 638 |
+
|
| 639 |
+
torch.cuda.memory.empty_cache()
|
| 640 |
+
wandb.finish()
|
| 641 |
+
|
| 642 |
+
# Zusammenfassung ausgeben
|
| 643 |
+
print("\n" + "="*50)
|
| 644 |
+
print("Kreuzvalidierung abgeschlossen")
|
| 645 |
+
print("="*50)
|
| 646 |
+
|
| 647 |
+
# Berechne Durchschnitts-Metriken
|
| 648 |
+
avg_f1 = np.mean([m['strict_f1'] for m in fold_metrics])
|
| 649 |
+
avg_precision = np.mean([m['strict_precision'] for m in fold_metrics])
|
| 650 |
+
avg_recall = np.mean([m['strict_recall'] for m in fold_metrics])
|
| 651 |
+
|
| 652 |
+
print(f"\nDurchschnittliche Metriken über {n_splits} Folds:")
|
| 653 |
+
print(f"Precision: {avg_precision:.10f}")
|
| 654 |
+
print(f"Recall: {avg_recall:.10f}")
|
| 655 |
+
print(f"F1-Score: {avg_f1:.10f}")
|
| 656 |
+
|
| 657 |
+
# Std-Abweichung
|
| 658 |
+
std_f1 = np.std([m['strict_f1'] for m in fold_metrics])
|
| 659 |
+
std_precision = np.std([m['strict_precision'] for m in fold_metrics])
|
| 660 |
+
std_recall = np.std([m['strict_recall'] for m in fold_metrics])
|
| 661 |
+
|
| 662 |
+
print(f"\nStandardabweichung über {n_splits} Folds:")
|
| 663 |
+
print(f"Precision: {std_precision:.10f}")
|
| 664 |
+
print(f"Recall: {std_recall:.10f}")
|
| 665 |
+
print(f"F1-Score: {std_f1:.10f}")
|
| 666 |
+
|
| 667 |
+
# Ergebnisse für jeden Fold ausgeben
|
| 668 |
+
for fold, metrics in enumerate(fold_metrics):
|
| 669 |
+
print(f"\nFold {fold+1} Ergebnisse:")
|
| 670 |
+
print(f"Precision: {metrics['strict_precision']:.4f}")
|
| 671 |
+
print(f"Recall: {metrics['strict_recall']:.4f}")
|
| 672 |
+
print(f"F1-Score: {metrics['strict_f1']:.4f}")
|
| 673 |
+
|
| 674 |
+
return {
|
| 675 |
+
'fold_metrics': fold_metrics,
|
| 676 |
+
'avg_metrics': {
|
| 677 |
+
'strict_f1': avg_f1,
|
| 678 |
+
'strict_precision': avg_precision,
|
| 679 |
+
'strict_recall': avg_recall
|
| 680 |
+
},
|
| 681 |
+
'std_metrics': {
|
| 682 |
+
'strict_f1': std_f1,
|
| 683 |
+
'strict_precision': std_precision,
|
| 684 |
+
'strict_recall': std_recall
|
| 685 |
+
}
|
| 686 |
+
}
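
# Usage note (added for clarity; the script filename is assumed): each fold can be run in its
# own process, e.g. `python exp027-2.py --fold 3` trains and evaluates only the third fold,
# which makes it easy to distribute the five folds across several GPUs or machines.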

# Load data
comments: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/training data/comments.csv")
task1: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/training data/task1.csv")
task2: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/training data/task2.csv")
comments = comments.merge(task1, on=["document", "comment_id"])

test_data: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/test data/comments.csv")

# Select a subset of the data for the experiment (e.g. 17000 comments); here all comments are used
experiment_data = comments

# Classifier with strict F1
classifier = SpanClassifierWithStrictF1('xlm-roberta-large')

# Run 5-fold cross-validation
cv_results = classifier.cross_validate(
    experiment_data,
    task2,
    n_splits=5,
    output_dir_prefix=experiment_name
)

# write results to text file
with open(f"scores.{experiment_name}.txt", 'a') as f:
    f.write(f'[{time.strftime("%Y-%m-%d %H:%M:%S")}] KFold cross validation of {experiment_name}\n')
    f.write(f'{cv_results}\n')

# Optional: train a final model on all data
trainer = classifier.train(experiment_data, task2, f'{experiment_name}-final')
torch.save(classifier.model.state_dict(), f'{experiment_name}_final_model.pth')

# Test prediction with the final model
test_texts = ["Das ist ein toller Kommentar!", "Schlechter Text hier.",
              "Sehr gutes Video. Danke! Ich finde Dich echt toll!", "Du bist doof!", "Das Licht ist echt gut.",
              "Team Einhorn", "Macht unbedingt weiter so!", "Das sehe ich ganz genauso.", "Stimmt, Du hast vollkommen Recht!",
              "Ich bin so dankbar ein #Lochinator zu sein"]

predictions = classifier.predict(test_texts)

for pred in predictions:
    print(f"\nText: {pred['text']}")
    for span in pred['spans']:
        print(f"  Span: '{span['text']}' ({span['start']}-{span['end']}) - {span['type']}")
subtask_2/exp027-2_retraining.py
ADDED
@@ -0,0 +1,736 @@
import os
import pickle
import sys
import time

import numpy as np
import pandas as pd
import torch
import wandb
from datasets import Dataset
from multiset import *
from sklearn.model_selection import train_test_split, StratifiedKFold
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback
)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

os.environ["WANDB_PROJECT"] = "GermEval2025-Substask2"
os.environ["WANDB_LOG_MODEL"] = "false"

experiment_name = 'exp027-2_retraining'

ALL_LABELS = ["affection declaration", "agreement", "ambiguous",
              "compliment", "encouragement", "gratitude", "group membership",
              "implicit", "positive feedback", "sympathy"]


def fine_grained_flausch_by_label(gold, predicted):
    gold['cid'] = gold['document'] + "_" + gold['comment_id'].apply(str)
    predicted['cid'] = predicted['document'] + "_" + predicted['comment_id'].apply(str)

    # annotation sets (predicted)
    pred_spans = Multiset()
    pred_spans_loose = Multiset()
    pred_types = Multiset()

    # annotation sets (gold)
    gold_spans = Multiset()
    gold_spans_loose = Multiset()
    gold_types = Multiset()

    for row in predicted.itertuples(index=False):
        pred_spans.add((row.cid, row.type, row.start, row.end))
        pred_spans_loose.add((row.cid, row.start, row.end))
        pred_types.add((row.cid, row.type))
    for row in gold.itertuples(index=False):
        gold_spans.add((row.cid, row.type, row.start, row.end))
        gold_spans_loose.add((row.cid, row.start, row.end))
        gold_types.add((row.cid, row.type))

    # precision = true_pos / (true_pos + false_pos)
    # recall = true_pos / (true_pos + false_neg)
    # f_1 = 2 * prec * rec / (prec + rec)

    results = {'TOTAL': {'STRICT': {}, 'SPANS': {}, 'TYPES': {}}}
    # label-wise evaluation (only for strict and type)
    for label in ALL_LABELS:
        results[label] = {'STRICT': {}, 'TYPES': {}}
        gold_spans_x = set(filter(lambda x: x[1].__eq__(label), gold_spans))
        pred_spans_x = set(filter(lambda x: x[1].__eq__(label), pred_spans))
        gold_types_x = set(filter(lambda x: x[1].__eq__(label), gold_types))
        pred_types_x = set(filter(lambda x: x[1].__eq__(label), pred_types))

        # strict: spans + type must match
        ### NOTE: x and y / x returns 0 if x = 0 and y/x otherwise (guards against zero division)
        strict_p = float(len(pred_spans_x)) and float(len(gold_spans_x.intersection(pred_spans_x))) / len(pred_spans_x)
        strict_r = float(len(gold_spans_x)) and float(len(gold_spans_x.intersection(pred_spans_x))) / len(gold_spans_x)
        strict_f = (strict_p + strict_r) and 2 * strict_p * strict_r / (strict_p + strict_r)
        results[label]['STRICT']['prec'] = strict_p
        results[label]['STRICT']['rec'] = strict_r
        results[label]['STRICT']['f1'] = strict_f

        # detection mode: only types must match (per post)
        types_p = float(len(pred_types_x)) and float(len(gold_types_x.intersection(pred_types_x))) / len(pred_types_x)
        types_r = float(len(gold_types_x)) and float(len(gold_types_x.intersection(pred_types_x))) / len(gold_types_x)
        types_f = (types_p + types_r) and 2 * types_p * types_r / (types_p + types_r)
        results[label]['TYPES']['prec'] = types_p
        results[label]['TYPES']['rec'] = types_r
        results[label]['TYPES']['f1'] = types_f

    # Overall evaluation
    # strict: spans + type must match
    strict_p = float(len(pred_spans)) and float(len(gold_spans.intersection(pred_spans))) / len(pred_spans)
    strict_r = float(len(gold_spans)) and float(len(gold_spans.intersection(pred_spans))) / len(gold_spans)
    strict_f = (strict_p + strict_r) and 2 * strict_p * strict_r / (strict_p + strict_r)
    results['TOTAL']['STRICT']['prec'] = strict_p
    results['TOTAL']['STRICT']['rec'] = strict_r
    results['TOTAL']['STRICT']['f1'] = strict_f

    # spans: spans must match
    spans_p = float(len(pred_spans_loose)) and float(len(gold_spans_loose.intersection(pred_spans_loose))) / len(pred_spans_loose)
    spans_r = float(len(gold_spans_loose)) and float(len(gold_spans_loose.intersection(pred_spans_loose))) / len(gold_spans_loose)
    spans_f = (spans_p + spans_r) and 2 * spans_p * spans_r / (spans_p + spans_r)
    results['TOTAL']['SPANS']['prec'] = spans_p
    results['TOTAL']['SPANS']['rec'] = spans_r
    results['TOTAL']['SPANS']['f1'] = spans_f

    # detection mode: only types must match (per post)
    types_p = float(len(pred_types)) and float(len(gold_types.intersection(pred_types))) / len(pred_types)
    types_r = float(len(gold_types)) and float(len(gold_types.intersection(pred_types))) / len(gold_types)
    types_f = (types_p + types_r) and 2 * types_p * types_r / (types_p + types_r)
    results['TOTAL']['TYPES']['prec'] = types_p
    results['TOTAL']['TYPES']['rec'] = types_r
    results['TOTAL']['TYPES']['f1'] = types_f

    return results
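
# Worked example of the zero-safe `x and y / x` idiom used above (illustrative, not part of the
# original script): with gold = {(c1, 'compliment', 0, 5), (c1, 'gratitude', 6, 12)} and
# pred = {(c1, 'compliment', 0, 5)}, strict_p = 1/1 = 1.0, strict_r = 1/2 = 0.5 and
# strict_f = 2 * 1.0 * 0.5 / 1.5 ≈ 0.667. With an empty prediction set, float(len(pred_spans))
# is 0.0, so the `and` short-circuits to 0.0 instead of raising a ZeroDivisionError.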

class SpanClassifierWithStrictF1:
    def __init__(self, model_name="deepset/gbert-base"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)

        self.labels = [
            "O",
            "B-positive feedback", "B-compliment", "B-affection declaration", "B-encouragement", "B-gratitude", "B-agreement", "B-ambiguous", "B-implicit", "B-group membership", "B-sympathy",
            "I-positive feedback", "I-compliment", "I-affection declaration", "I-encouragement", "I-gratitude", "I-agreement", "I-ambiguous", "I-implicit", "I-group membership", "I-sympathy"
        ]
        self.label2id = {label: i for i, label in enumerate(self.labels)}
        self.id2label = {i: label for i, label in enumerate(self.labels)}

    def create_dataset(self, comments_df, spans_df):
        """Build the dataset with BIO labels and store the evaluation data"""
        examples = []
        eval_data = []  # for the strict F1 computation

        spans_grouped = spans_df.groupby(['document', 'comment_id'])

        for _, row in comments_df.iterrows():
            text = row['comment']
            document = row['document']
            comment_id = row['comment_id']
            key = (document, comment_id)

            # true spans for this comment
            if key in spans_grouped.groups:
                true_spans = [(span_type, int(start), int(end))
                              for span_type, start, end in
                              spans_grouped.get_group(key)[['type', 'start', 'end']].values]
            else:
                true_spans = []

            # tokenization
            tokenized = self.tokenizer(text, truncation=True, max_length=512,
                                       return_offsets_mapping=True)

            # create BIO labels
            labels = self._create_bio_labels(tokenized['offset_mapping'],
                                             spans_grouped.get_group(key)[['start', 'end', 'type']].values
                                             if key in spans_grouped.groups else [])

            examples.append({
                'input_ids': tokenized['input_ids'],
                'attention_mask': tokenized['attention_mask'],
                'labels': labels
            })

            # store evaluation data
            eval_data.append({
                'text': text,
                'offset_mapping': tokenized['offset_mapping'],
                'true_spans': true_spans,
                'document': document,
                'comment_id': comment_id
            })

        return examples, eval_data

    def _create_bio_labels(self, offset_mapping, spans):
        """Create BIO labels for the tokens"""
        labels = [0] * len(offset_mapping)  # 0 = "O"

        for start, end, type_label in spans:
            for i, (token_start, token_end) in enumerate(offset_mapping):
                if token_start is None:  # special tokens
                    continue

                # token overlaps with the span
                if token_start < end and token_end > start:
                    if token_start <= start:
                        if labels[i] != 0:
                            # don't overwrite labels if spans are overlapping; just skip the span
                            break
                        labels[i] = self.label2id[f'B-{type_label}']  # e.g. B-compliment
                    else:
                        labels[i] = self.label2id[f'I-{type_label}']  # e.g. I-compliment

        return labels

    def _predictions_to_dataframe(self, predictions_list, comments_df_subset):
        """Convert predictions to a DataFrame for the Flausch metric"""
        pred_data = []

        for i, pred in enumerate(predictions_list):
            if i < len(comments_df_subset):
                row = comments_df_subset.iloc[i]
                document = row['document']
                comment_id = row['comment_id']

                for span in pred['spans']:
                    pred_data.append({
                        'document': document,
                        'comment_id': comment_id,
                        'type': span['type'],
                        'start': span['start'],
                        'end': span['end']
                    })

        return pd.DataFrame(pred_data)

    # --- helper that builds a DataFrame of spans from eval data + predictions ---
    def _build_span_dfs(self, eval_data, batch_pred_spans):
        """
        eval_data: list of dicts with keys document, comment_id, true_spans
        batch_pred_spans: list of lists of (type, start, end)
        returns (gold_df, pred_df) suitable for fine_grained_flausch_by_label
        """
        rows_gold = []
        rows_pred = []
        for item, pred_spans in zip(eval_data, batch_pred_spans):
            doc = item['document']
            cid = item['comment_id']
            # gold
            for t, s, e in item['true_spans']:
                rows_gold.append({
                    'document': doc,
                    'comment_id': cid,
                    'type': t,
                    'start': s,
                    'end': e
                })
            # pred
            for t, s, e in pred_spans:
                rows_pred.append({
                    'document': doc,
                    'comment_id': cid,
                    'type': t,
                    'start': s,
                    'end': e
                })
        gold_df = pd.DataFrame(rows_gold, columns=['document', 'comment_id', 'type', 'start', 'end'])
        pred_df = pd.DataFrame(rows_pred, columns=['document', 'comment_id', 'type', 'start', 'end'])
        return gold_df, pred_df

    def compute_metrics(self, eval_pred):
        """
        Called by the HF Trainer at each evaluation step.
        We collect batch predictions, reconstruct gold/pred spans,
        call fine_grained_flausch_by_label and return the TOTAL/STRICT metrics.
        """
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=2)

        # reconstruct spans per example in this batch
        batch_pred_spans = []
        for i, (p_seq, lab_seq) in enumerate(zip(preds, labels)):
            # skip padding (-100)
            valid_preds = []
            valid_offsets = []
            offsets = self.current_eval_data[i]['offset_mapping']
            for j, (p, l) in enumerate(zip(p_seq, lab_seq)):
                if l != -100:
                    valid_preds.append(int(p))
                    valid_offsets.append(offsets[j])
            # convert to spans
            pred_spans = self._predictions_to_spans(valid_preds, valid_offsets,
                                                    self.current_eval_data[i]['text'])
            # to (type, start, end) tuples
            batch_pred_spans.append([(sp['type'], sp['start'], sp['end'])
                                     for sp in pred_spans])

        # build the gold/pred DataFrames
        gold_df, pred_df = self._build_span_dfs(self.current_eval_data,
                                                batch_pred_spans)

        # call the fine-grained metrics
        results = fine_grained_flausch_by_label(gold_df, pred_df)

        # extract the TOTAL/STRICT metrics
        total = results['TOTAL']['STRICT']
        return {
            'strict_prec': torch.tensor(total['prec'], dtype=torch.float32),
            'strict_rec': torch.tensor(total['rec'], dtype=torch.float32),
            'strict_f1': torch.tensor(total['f1'], dtype=torch.float32),
        }

    def evaluate_by_label(self, comments_df, spans_df):
        """
        Replaces evaluate_strict_f1. Runs a full pass over all comments,
        uses self.predict() to get spans, then calls fine_grained_flausch_by_label
        and prints & returns the TOTAL metrics.
        """
        # 1) run predictions
        texts = comments_df['comment'].tolist()
        docs = comments_df['document'].tolist()
        cids = comments_df['comment_id'].tolist()
        preds = self.predict(texts)

        # 2) build gold and pred lists
        gold_rows = []
        for (_, row) in comments_df.iterrows():
            key = (row['document'], row['comment_id'])
            # get all true spans for this comment_id
            group = spans_df[
                (spans_df.document == row['document']) &
                (spans_df.comment_id == row['comment_id'])
            ]
            for _, sp in group.iterrows():
                gold_rows.append({
                    'document': row['document'],
                    'comment_id': row['comment_id'],
                    'type': sp['type'],
                    'start': sp['start'],
                    'end': sp['end']
                })

        pred_rows = []
        for doc, cid, p in zip(docs, cids, preds):
            for sp in p['spans']:
                pred_rows.append({
                    'document': doc,
                    'comment_id': cid,
                    'type': sp['type'],
                    'start': sp['start'],
                    'end': sp['end']
                })

        gold_df = pd.DataFrame(gold_rows, columns=['document', 'comment_id', 'type', 'start', 'end'])
        pred_df = pd.DataFrame(pred_rows, columns=['document', 'comment_id', 'type', 'start', 'end'])

        # 3) call the fine-grained metric
        results = fine_grained_flausch_by_label(gold_df, pred_df)

        # 4) extract and print
        total = results['TOTAL']
        print("\n=== EVALUATION BY FLAUSCH METRICS ===")
        for mode in ['STRICT', 'SPANS', 'TYPES']:
            m = total[mode]
            print(f"{mode:6} P={m['prec']:.4f} R={m['rec']:.4f} F1={m['f1']:.4f}")

        return results

    def _predictions_to_spans(self, predicted_labels, offset_mapping, text):
        """Convert token predictions to spans"""
        spans = []
        current_span = None

        for i, label_id in enumerate(predicted_labels):
            if i >= len(offset_mapping):
                break

            label = self.id2label[label_id]
            token_start, token_end = offset_mapping[i]

            if token_start is None:
                continue

            if label.startswith('B-'):
                if current_span:
                    spans.append(current_span)
                current_span = {
                    'type': label[2:],
                    'start': token_start,
                    'end': token_end,
                    'text': text[token_start:token_end]
                }
            elif label.startswith('I-') and current_span:
                # note: the I- tag's type is not checked against current_span['type'];
                # an I- tag simply extends whatever span is currently open
                current_span['end'] = token_end
                current_span['text'] = text[current_span['start']:current_span['end']]
            else:
                if current_span:
                    spans.append(current_span)
                current_span = None

        if current_span:
            spans.append(current_span)

        return spans
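
    # Decoding illustration (added for clarity): for the token label sequence
    # [B-compliment, I-compliment, O, B-gratitude] the loop above yields two spans, a
    # 'compliment' span covering the first two tokens and a 'gratitude' span for the last
    # token. An I- tag that is not preceded by an open span falls into the else branch
    # and is ignored, i.e. treated like O.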

    def predict(self, texts):
        """Prediction for new texts"""
        if not hasattr(self, 'model'):
            raise ValueError("The model has to be trained first!")

        predictions = []
        device = next(self.model.parameters()).device

        for text in texts:
            # tokenization
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True,
                                    max_length=512, return_offsets_mapping=True)

            offset_mapping = inputs.pop('offset_mapping')
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # prediction
            with torch.no_grad():
                outputs = self.model(**inputs)

            predicted_labels = torch.argmax(outputs.logits, dim=2)[0].cpu().numpy()

            # extract spans
            spans = self._predictions_to_spans(predicted_labels, offset_mapping[0], text)
            predictions.append({'text': text, 'spans': spans})

        return predictions
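
    # Performance note (added): predict() runs one text at a time (batch size 1); batching
    # the tokenizer and model calls would likely speed up evaluation, but the original
    # single-text loop is kept here unchanged.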

    def train(self, comments_df, spans_df, experiment_name):
        wandb.init(project=os.environ["WANDB_PROJECT"], name=f"{experiment_name}",
                   group=experiment_name)

        # Build the dataset (same code path as in cross-validation)
        examples, eval_data = self.create_dataset(comments_df, spans_df)
        train_examples, val_examples = train_test_split(examples, test_size=0.1, random_state=42)

        # Split the evaluation data accordingly (the same random_state keeps both splits aligned)
        train_indices, val_indices = train_test_split(range(len(examples)), test_size=0.1, random_state=42)
        self.current_eval_data = [eval_data[i] for i in val_indices]

        # Note: this "test" set is the same 10% split that is also used for model selection during training
        test_comments = comments_df.iloc[val_indices].reset_index(drop=True)

        train_dataset = Dataset.from_list(train_examples)
        val_dataset = Dataset.from_list(val_examples)

        # Reinitialize the model
        model = AutoModelForTokenClassification.from_pretrained(
            self.model_name,
            num_labels=len(self.labels),
            id2label=self.id2label,
            label2id=self.label2id
        )

        # Training arguments
        fold_output_dir = f"{experiment_name}"
        training_args = TrainingArguments(
            output_dir=fold_output_dir,
            learning_rate=2e-5,
            warmup_steps=500,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            num_train_epochs=20,
            eval_strategy="steps",
            eval_steps=40,
            save_strategy="steps",
            save_steps=40,
            load_best_model_at_end=True,
            metric_for_best_model="strict_f1",
            greater_is_better=True,
            logging_steps=10,
            logging_strategy="steps",
            report_to="all",
            disable_tqdm=False,
            seed=42,
            save_total_limit=3,
        )

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=DataCollatorForTokenClassification(self.tokenizer),
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=87)]
            # 87 steps = 3.0 epochs with 29 steps per epoch
        )

        # Training
        print(f"Training on {len(train_dataset)} examples")
        print(f"Validation on {len(val_dataset)} examples")
        trainer.train()

        # Keep the current model
        self.model = model

        # Evaluate the model on the test data
        print(f"Evaluating on {len(test_comments)} test examples")
        metrics = self.evaluate_by_label(test_comments, spans_df)
        wandb.log({
            'strict_f1': metrics['TOTAL']['STRICT']['f1'],
            'strict_precision': metrics['TOTAL']['STRICT']['prec'],
            'strict_recall': metrics['TOTAL']['STRICT']['rec'],
            'spans_f1': metrics['TOTAL']['SPANS']['f1'],
            'types_f1': metrics['TOTAL']['TYPES']['f1']
        })

        # Save the model
        torch.save(model.state_dict(), f'{fold_output_dir}_model.pth')

        torch.cuda.memory.empty_cache()
        wandb.finish()

        return trainer

    def cross_validate(self, comments_df, spans_df, n_splits=5, output_dir_prefix="span-classifier-cv"):
        """Run n-fold cross-validation with StratifiedKFold"""

        # Build labels for stratification (based on the first span type of a comment)
        strat_labels = []
        spans_grouped = spans_df.groupby(['document', 'comment_id'])
        for _, row in comments_df.iterrows():
            key = (row['document'], row['comment_id'])
            # first span type if the comment has spans, otherwise 0
            has_spans = spans_grouped.get_group(key).iloc[0]['type'] if key in spans_grouped.groups and len(spans_grouped.get_group(key)) > 0 else 0
            strat_labels.append(has_spans)

        # Create StratifiedKFold
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

        # Store metrics for each fold
        fold_metrics = []

        # Iterate over folds
        for fold, (train_idx, test_idx) in enumerate(skf.split(range(len(comments_df)), strat_labels)):
            # Optionally run only a single fold, selected via `--fold <k>` on the command line
            if '--fold' in sys.argv:
                fold_arg = int(sys.argv[sys.argv.index('--fold') + 1])
                if fold + 1 != fold_arg:
                    continue

            wandb.init(project=os.environ["WANDB_PROJECT"], name=f"{experiment_name}-fold-{fold+1}",
                       group=experiment_name)

            print(f"\n{'='*50}")
            print(f"Fold {fold+1}/{n_splits}")
            print(f"{'='*50}")

            # Comments for this fold
            train_comments = comments_df.iloc[train_idx].reset_index(drop=True)
            test_comments = comments_df.iloc[test_idx].reset_index(drop=True)

            # Rebuild the dataset for this fold
            examples, eval_data = self.create_dataset(train_comments, spans_df)
            train_examples, val_examples = train_test_split(examples, test_size=0.1, random_state=42)

            # Split the evaluation data accordingly
            train_indices, val_indices = train_test_split(range(len(examples)), test_size=0.1, random_state=42)
            self.current_eval_data = [eval_data[i] for i in val_indices]

            train_dataset = Dataset.from_list(train_examples)
            val_dataset = Dataset.from_list(val_examples)

            # Reinitialize the model
            model = AutoModelForTokenClassification.from_pretrained(
                self.model_name,
                num_labels=len(self.labels),
                id2label=self.id2label,
                label2id=self.label2id
            )

            # Training arguments
            fold_output_dir = f"{output_dir_prefix}-fold-{fold+1}"
            training_args = TrainingArguments(
                output_dir=fold_output_dir,
                learning_rate=2e-5,
                warmup_steps=500,
                per_device_train_batch_size=32,
                per_device_eval_batch_size=32,
                num_train_epochs=15,
                eval_strategy="steps",
                eval_steps=40,
                save_strategy="steps",
                save_steps=40,
                load_best_model_at_end=True,
                metric_for_best_model="strict_f1",
                greater_is_better=True,
                logging_steps=10,
                logging_strategy="steps",
                report_to="all",
                disable_tqdm=False,
                seed=42,
                save_total_limit=3,
            )

            # Trainer
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                data_collator=DataCollatorForTokenClassification(self.tokenizer),
                compute_metrics=self.compute_metrics,
                callbacks=[EarlyStoppingCallback(early_stopping_patience=87)]  # 87 steps = 3.0 epochs with 29 steps per epoch
            )

            # Training
            print(f"Training on {len(train_dataset)} examples")
            print(f"Validation on {len(val_dataset)} examples")
            trainer.train()

            # Keep the current model
            self.model = model

            # Evaluate the model on the test data
            print(f"Evaluating on {len(test_comments)} test examples")
            flausch_results = self.evaluate_by_label(test_comments, spans_df)

            # Extract the main metrics for fold_metrics
            metrics = {
                'strict_f1': flausch_results['TOTAL']['STRICT']['f1'],
                'strict_precision': flausch_results['TOTAL']['STRICT']['prec'],
                'strict_recall': flausch_results['TOTAL']['STRICT']['rec'],
                'spans_f1': flausch_results['TOTAL']['SPANS']['f1'],
                'spans_precision': flausch_results['TOTAL']['SPANS']['prec'],
                'spans_recall': flausch_results['TOTAL']['SPANS']['rec'],
                'types_f1': flausch_results['TOTAL']['TYPES']['f1'],
                'types_precision': flausch_results['TOTAL']['TYPES']['prec'],
                'types_recall': flausch_results['TOTAL']['TYPES']['rec'],
                'full_results': flausch_results
            }

            fold_metrics.append(metrics)
            wandb.log(metrics, step=fold + 1)

            # Save the model
            torch.save(model.state_dict(), f'{fold_output_dir}_model.pth')

            test_predictions = self.predict(test_comments['comment'].tolist())

            # Save the metrics
            with open(f"test_results.{experiment_name}.fold-{fold+1}.pkl", "wb") as p:
                pickle.dump((train_comments, test_comments, test_predictions, train_examples, val_examples), p)

            with open(f"scores.{experiment_name}.txt", 'a') as f:
                f.write(f'[{time.strftime("%Y-%m-%d %H:%M:%S")}] Fold {fold+1} results:\n')
                f.write(f"{experiment_name} fold-{fold+1} {metrics}\n")

            torch.cuda.memory.empty_cache()
            wandb.finish()

        # Print summary
        print("\n" + "="*50)
        print("Cross-validation finished")
        print("="*50)

        # Compute average metrics
        avg_f1 = np.mean([m['strict_f1'] for m in fold_metrics])
        avg_precision = np.mean([m['strict_precision'] for m in fold_metrics])
        avg_recall = np.mean([m['strict_recall'] for m in fold_metrics])

        print(f"\nAverage metrics over {n_splits} folds:")
        print(f"Precision: {avg_precision:.10f}")
        print(f"Recall: {avg_recall:.10f}")
        print(f"F1 score: {avg_f1:.10f}")

        # Standard deviation
        std_f1 = np.std([m['strict_f1'] for m in fold_metrics])
        std_precision = np.std([m['strict_precision'] for m in fold_metrics])
        std_recall = np.std([m['strict_recall'] for m in fold_metrics])

        print(f"\nStandard deviation over {n_splits} folds:")
        print(f"Precision: {std_precision:.10f}")
        print(f"Recall: {std_recall:.10f}")
        print(f"F1 score: {std_f1:.10f}")

        # Print results for each fold
        for fold, metrics in enumerate(fold_metrics):
            print(f"\nFold {fold+1} results:")
            print(f"Precision: {metrics['strict_precision']:.4f}")
            print(f"Recall: {metrics['strict_recall']:.4f}")
            print(f"F1 score: {metrics['strict_f1']:.4f}")

        return {
            'fold_metrics': fold_metrics,
            'avg_metrics': {
                'strict_f1': avg_f1,
                'strict_precision': avg_precision,
                'strict_recall': avg_recall
            },
            'std_metrics': {
                'strict_f1': std_f1,
                'strict_precision': std_precision,
                'strict_recall': std_recall
            }
        }


# Load data
comments: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/training data/comments.csv")
task1: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/training data/task1.csv")
task2: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/training data/task2.csv")
comments = comments.merge(task1, on=["document", "comment_id"])

test_data: pd.DataFrame = pd.read_csv("../../share-GermEval2025-data/Data/test data/comments.csv")

# Select a subset of the data for the experiment (e.g. 17000 comments); here all comments are used
experiment_data = comments

# Classifier with strict F1
classifier = SpanClassifierWithStrictF1('xlm-roberta-large')

# Run 5-fold cross-validation (disabled in this retraining script)
#cv_results = classifier.cross_validate(
#    experiment_data,
#    task2,
#    n_splits=5,
#    output_dir_prefix=experiment_name
#)
#
## write results to text file
#with open(f"scores.{experiment_name}.txt", 'a') as f:
#    f.write(f'[{time.strftime("%Y-%m-%d %H:%M:%S")}] KFold cross validation of {experiment_name}\n')
#    f.write(f'{cv_results}\n')

# Optional: train the final model on all data
trainer = classifier.train(experiment_data, task2, f'{experiment_name}-final')
torch.save(classifier.model.state_dict(), f'{experiment_name}_final_model.pth')

# Test prediction with the final model
test_texts = ["Das ist ein toller Kommentar!", "Schlechter Text hier.",
              "Sehr gutes Video. Danke! Ich finde Dich echt toll!", "Du bist doof!", "Das Licht ist echt gut.",
              "Team Einhorn", "Macht unbedingt weiter so!", "Das sehe ich ganz genauso.", "Stimmt, Du hast vollkommen Recht!",
              "Ich bin so dankbar ein #Lochinator zu sein"]

predictions = classifier.predict(test_texts)

for pred in predictions:
    print(f"\nText: {pred['text']}")
    for span in pred['spans']:
        print(f"  Span: '{span['text']}' ({span['start']}-{span['end']}) - {span['type']}")
subtask_2/submission_subtask2-2.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
subtask_2/submission_subtask2.ipynb
ADDED
The diff for this file is too large to render. See raw diff.