import torch
import torch.nn as nn
from transformers import AutoModel

###################################################################################

# Extended regressor class: one shared encoder, but several independent heads
class BertMultiHeadRegressor(nn.Module):
    """
    Mehrkopf-Regression auf einem beliebigen HF-Encoder (BERT/RoBERTa/DeBERTa/ModernBERT).
    - Gemeinsamer Encoder
    - n unabhängige Regressionsköpfe (je 1 Wert)
    - Robustes Pooling (Pooler wenn vorhanden, sonst maskiertes Mean)
    - Partielles Unfreezen ab `unfreeze_from`
    """
    def __init__(self, pretrained_model_name: str,
                 n_heads: int = 8,
                 unfreeze_from: int = 8,
                 dropout: float = 0.1):
        super().__init__()

        # Load any HF encoder
        self.encoder = AutoModel.from_pretrained(
            pretrained_model_name,
            low_cpu_mem_usage=False  # avoids the accelerate dependency at init
        )
        hidden_size = self.encoder.config.hidden_size

        # First freeze all encoder parameters …
        for p in self.encoder.parameters():
            p.requires_grad = False

        # … then unfreeze layers from `unfreeze_from` onward (if present)
        # Most encoders expose them under `.encoder.layer`
        encoder_block = getattr(self.encoder, "encoder", None)
        layers = getattr(encoder_block, "layer", None)
        if layers is not None:
            for layer in layers[unfreeze_from:]:
                for p in layer.parameters():
                    p.requires_grad = True
        else:
            # Fallback: if there is no classic layer list, keep everything frozen
            pass

        self.dropout = nn.Dropout(dropout)
        self.heads = nn.ModuleList([nn.Linear(hidden_size, 1) for _ in range(n_heads)])

    def _pool(self, outputs, attention_mask):
        """
        Robustes Pooling:
        - Wenn pooler_output vorhanden: nutzen (BERT/RoBERTa)
        - Sonst: maskiertes Mean-Pooling über last_hidden_state (z. B. DeBERTaV3)
        """
        pooler = getattr(outputs, "pooler_output", None)
        if pooler is not None:
            return pooler  # [B, H]

        last_hidden = outputs.last_hidden_state  # [B, T, H]
        mask = attention_mask.unsqueeze(-1).float()  # [B, T, 1]
        summed = (last_hidden * mask).sum(dim=1)     # [B, H]
        denom = mask.sum(dim=1).clamp(min=1e-6)      # [B, 1]
        return summed / denom

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "return_dict": True,
        }
        # Only pass token_type_ids when present; some encoders (e.g. ModernBERT)
        # do not accept this argument.
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        pooled = self._pool(outputs, attention_mask)    # [B, H]
        pooled = self.dropout(pooled)
        preds = [head(pooled) for head in self.heads]   # n × [B, 1]
        return torch.cat(preds, dim=1)                  # [B, n_heads]
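
###################################################################################

# Minimal usage sketch (not part of the original file; added for illustration under
# assumptions): shows how the regressor could be instantiated and called with a
# matching tokenizer. The checkpoint name, example texts, zero targets and learning
# rate below are hypothetical placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "bert-base-uncased"  # hypothetical checkpoint; any HF encoder works
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BertMultiHeadRegressor(model_name, n_heads=8, unfreeze_from=8)

    texts = ["first example sentence", "second example sentence"]
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    preds = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        token_type_ids=batch.get("token_type_ids"),
    )
    print(preds.shape)  # torch.Size([2, 8]) -> one value per head

    # Training would typically pair the [B, n_heads] output with per-head targets
    # of the same shape and an MSE loss; only the unfrozen encoder layers and the
    # heads carry gradients, so the optimizer is built from trainable params only.
    targets = torch.zeros_like(preds)  # placeholder targets
    loss = nn.MSELoss()(preds, targets)
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=2e-5
    )
    loss.backward()
    optimizer.step()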