Update app.py
Browse files
app.py
CHANGED
|
@@ -124,6 +124,9 @@ def main():
|
|
| 124 |
length_penalty = st.sidebar.number_input(
|
| 125 |
"Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
|
| 126 |
)
|
|
|
|
|
|
|
|
|
|
| 127 |
st.sidebar.markdown(
|
| 128 |
"""For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
|
| 129 |
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
|
|
@@ -132,6 +135,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
|
|
| 132 |
params = {
|
| 133 |
"num_beams": num_beams,
|
| 134 |
"num_beam_groups": num_beam_groups,
|
|
|
|
| 135 |
"length_penalty": length_penalty,
|
| 136 |
"early_stopping": True,
|
| 137 |
}
|
|
|
|
| 124 |
length_penalty = st.sidebar.number_input(
|
| 125 |
"Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
|
| 126 |
)
|
| 127 |
+
diversity_penalty = st.sidebar.number_input(
|
| 128 |
+
"Diversity penalty", min_value=0.0, max_value=2.0, value=0.0, step=0.1
|
| 129 |
+
)
|
| 130 |
st.sidebar.markdown(
|
| 131 |
"""For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
|
| 132 |
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
|
|
|
|
| 135 |
params = {
|
| 136 |
"num_beams": num_beams,
|
| 137 |
"num_beam_groups": num_beam_groups,
|
| 138 |
+
"diversity_penalty": diversity_penalty if num_beam_groups > 1 else 0.0,
|
| 139 |
"length_penalty": length_penalty,
|
| 140 |
"early_stopping": True,
|
| 141 |
}
|