import numpy as np
import pandas as pd
import torch

from dataloader import create_dataloader
from utils import *
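
# Helpers imported from utils (signatures inferred from their usage below):
#   batch_list(seq, n)         -> split a flat token list into chunks of at most n items
#   pad_seq(seq, n)            -> pad a chunk out to exactly n items
#   align_decoded(raw, dec, p) -> map token probabilities onto raw-text character offsets
#   get_sequential_spans(mask) -> (start, end) pairs for each contiguous run of 1s
#   is_overlap(spans, span)    -> whether `span` overlaps any span already in `spans`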


def predict_segmentation(inp, model, device, batch_size=8):
    """Run the model over pre-tokenized chunks and return per-token probabilities."""
    test_loader = create_dataloader(inp, batch_size)
    predictions = []
    with torch.no_grad():  # inference only, no gradients needed
        for batch in test_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # Sigmoid turns the single-logit head into a per-token probability.
            p = torch.sigmoid(model(**batch).logits).cpu().numpy()
            predictions.append(p)
    return np.concatenate(predictions, axis=0)


def create_data(text, tokenizer, seq_len=512):
    """Tokenize text and split every tokenizer field into padded, fixed-length chunks."""
    tokens = tokenizer(text, add_special_tokens=False)
    # Chunk each field into seq_len-sized pieces, padding the final piece to seq_len.
    _token_batches = {k: [pad_seq(x, seq_len) for x in batch_list(v, seq_len)]
                      for (k, v) in tokens.items()}
    n_batches = len(_token_batches['input_ids'])
    # Re-assemble one dict per chunk, e.g. {'input_ids': [...], 'attention_mask': [...]}.
    return [{k: v[i] for k, v in _token_batches.items()}
            for i in range(n_batches)]
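
# For example, a 1,200-token note with seq_len=512 would yield three dicts, each
# holding 512-length padded lists for every field the tokenizer returned
# (typically input_ids and attention_mask).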


def segment_tokens(notes, model, tokenizer, device, batch_size=8):
    """Return a {note_id: per-token probability array} map for every note."""
    predictions = {}
    for note in notes.itertuples():
        note_id = note.note_id
        raw_text = note.text.lower()
        inp = create_data(raw_text, tokenizer)
        pred_probs = predict_segmentation(inp, model, device, batch_size=batch_size)
        # Drop the trailing singleton logit dimension, then flatten the chunk
        # dimension so probabilities line up with the note's full token sequence.
        pred_probs = np.squeeze(pred_probs, -1)
        pred_probs = np.concatenate(pred_probs)
        predictions[note_id] = pred_probs
    return predictions


def segment(notes, model, tokenizer, device, thresh, batch_size=8):
    """Threshold per-token probabilities into character spans; return them as a DataFrame."""
    predictions = []
    predictions_prob_map = segment_tokens(notes, model, tokenizer, device, batch_size)
    for note in notes.itertuples():
        note_id = note.note_id
        raw_text = note.text
        # Round-trip through the tokenizer so predictions can be aligned back
        # onto the original text's character offsets.
        decoded_text = tokenizer.decode(tokenizer.encode(raw_text, add_special_tokens=False))
        pred_probs = predictions_prob_map[note_id]
        _, pred_probs = align_decoded(raw_text, decoded_text, pred_probs)
        pred_probs = np.array(pred_probs, 'float32')
        pred = (pred_probs > thresh).astype('uint8')
        spans = get_sequential_spans(pred)
        note_predictions = {'note_id': [], 'start': [], 'end': [], 'mention': [], 'score': []}
        for (start, end) in spans:
            note_predictions['note_id'].append(note_id)
            # Score each span by the mean probability over its characters.
            note_predictions['score'].append(pred_probs[start:end].mean())
            note_predictions['start'].append(start)
            note_predictions['end'].append(end)
            note_predictions['mention'].append(raw_text[start:end])
        note_predictions = pd.DataFrame(note_predictions)
        note_predictions = note_predictions.sort_values('score', ascending=False)
        # Remove overlapping spans, keeping the highest-scoring one.
        seen_spans = set()
        keep = []
        for span in note_predictions[['start', 'end']].values:
            span = tuple(span)
            is_new = not is_overlap(seen_spans, span)
            if is_new:
                seen_spans.add(span)
            keep.append(is_new)
        note_predictions = note_predictions[keep]
        predictions.append(note_predictions)
    predictions = pd.concat(predictions).reset_index(drop=True)
    return predictions
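

# Minimal usage sketch, assuming a Hugging Face token-classification model with a
# single-logit head and a `notes` DataFrame with `note_id` and `text` columns.
# The checkpoint path below is a placeholder, not this project's actual model.
if __name__ == '__main__':
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained('path/to/checkpoint')  # placeholder
    model = AutoModelForTokenClassification.from_pretrained('path/to/checkpoint').to(device)
    model.eval()

    notes = pd.DataFrame({'note_id': [1], 'text': ['Example note text to segment.']})
    print(segment(notes, model, tokenizer, device, thresh=0.5))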