Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
·
ad7fc61
1
Parent(s):
a6de017
Use classifier category if transformer generates unknown category
Browse files- src/predict.py +6 -4
src/predict.py
CHANGED
|
@@ -106,7 +106,7 @@ class ClassifierArguments:
|
|
| 106 |
default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
|
| 107 |
|
| 108 |
|
| 109 |
-
def
|
| 110 |
"""Use classifier to filter predictions"""
|
| 111 |
if not predictions:
|
| 112 |
return predictions
|
|
@@ -134,8 +134,10 @@ def add_predictions(predictions, classifier_args): # classifier, vectorizer,
|
|
| 134 |
if classifier_category is None and classifier_probability > classifier_args.min_probability:
|
| 135 |
continue # Ignore
|
| 136 |
|
| 137 |
-
if
|
| 138 |
-
|
|
|
|
|
|
|
| 139 |
prediction['category'] = classifier_category
|
| 140 |
|
| 141 |
prediction['probability'] = predicted_probabilities[prediction['category']]
|
|
@@ -173,7 +175,7 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
|
|
| 173 |
|
| 174 |
# TODO add back
|
| 175 |
if classifier_args is not None:
|
| 176 |
-
predictions =
|
| 177 |
|
| 178 |
return predictions
|
| 179 |
|
|
|
|
| 106 |
default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
|
| 107 |
|
| 108 |
|
| 109 |
+
def filter_and_add_probabilities(predictions, classifier_args): # classifier, vectorizer,
|
| 110 |
"""Use classifier to filter predictions"""
|
| 111 |
if not predictions:
|
| 112 |
return predictions
|
|
|
|
| 134 |
if classifier_category is None and classifier_probability > classifier_args.min_probability:
|
| 135 |
continue # Ignore
|
| 136 |
|
| 137 |
+
if (prediction['category'] not in predicted_probabilities) \
|
| 138 |
+
or (classifier_category is not None and classifier_probability > 0.5): # TODO make param
|
| 139 |
+
# Unknown category or we are confident enough to overrule,
|
| 140 |
+
# so change category to what was predicted by classifier
|
| 141 |
prediction['category'] = classifier_category
|
| 142 |
|
| 143 |
prediction['probability'] = predicted_probabilities[prediction['category']]
|
|
|
|
| 175 |
|
| 176 |
# TODO add back
|
| 177 |
if classifier_args is not None:
|
| 178 |
+
predictions = filter_and_add_probabilities(predictions, classifier_args)
|
| 179 |
|
| 180 |
return predictions
|
| 181 |
|