Maharshi Gor
committed on
Commit
·
849566b
1
Parent(s):
c1ae336
Add user input validation to pipeline interfaces and error display on pipeline change.
Browse files
- src/components/model_pipeline/model_pipeline.py +13 -6
- src/components/model_pipeline/state_manager.py +120 -42
- src/components/model_pipeline/tossup_pipeline.py +4 -3
- src/components/quizbowl/bonus.py +2 -0
- src/components/quizbowl/tossup.py +3 -1
- src/components/quizbowl/validation.py +30 -0
- src/components/structs.py +16 -0
- src/workflows/validators.py +27 -23
src/components/model_pipeline/model_pipeline.py
CHANGED
|
@@ -10,13 +10,14 @@ from components.model_pipeline.state_manager import (
|
|
| 10 |
PipelineState,
|
| 11 |
PipelineStateManager,
|
| 12 |
PipelineUIState,
|
|
|
|
| 13 |
TossupPipelineState,
|
| 14 |
TossupPipelineStateManager,
|
| 15 |
)
|
| 16 |
from components.model_step.model_step import ModelStepComponent
|
| 17 |
from components.utils import make_state
|
| 18 |
from workflows.structs import ModelStep, TossupWorkflow, Workflow
|
| 19 |
-
from workflows.validators import WorkflowValidator
|
| 20 |
|
| 21 |
from .state_manager import get_output_panel_state
|
| 22 |
|
|
@@ -33,6 +34,7 @@ class PipelineInterface:
|
|
| 33 |
ui_state: PipelineUIState | None = None,
|
| 34 |
model_options: list[str] = None,
|
| 35 |
config: dict = {},
|
|
|
|
| 36 |
):
|
| 37 |
self.app = app
|
| 38 |
self.model_options = model_options
|
|
@@ -50,10 +52,10 @@ class PipelineInterface:
|
|
| 50 |
|
| 51 |
if isinstance(workflow, TossupWorkflow):
|
| 52 |
pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
|
| 53 |
-
self.sm = TossupPipelineStateManager()
|
| 54 |
else:
|
| 55 |
pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
|
| 56 |
-
self.sm = PipelineStateManager()
|
| 57 |
self.pipeline_state = make_state(pipeline_state.model_dump())
|
| 58 |
|
| 59 |
def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
|
|
@@ -169,7 +171,11 @@ class PipelineInterface:
|
|
| 169 |
"""Validate the workflow."""
|
| 170 |
try:
|
| 171 |
state = self.sm.make_pipeline_state(state_dict)
|
| 172 |
-
WorkflowValidator(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
except ValueError as e:
|
| 174 |
logger.exception(e)
|
| 175 |
state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
|
|
@@ -244,6 +250,7 @@ class PipelineInterface:
|
|
| 244 |
def _render_pipeline_preview(self):
|
| 245 |
export_btn = gr.Button("Export Pipeline", elem_classes="export-button", visible=False)
|
| 246 |
# components.append(export_btn)
|
|
|
|
| 247 |
|
| 248 |
# Add a code box to display the workflow JSON
|
| 249 |
# with gr.Column(elem_classes="workflow-json-container"):
|
|
@@ -262,7 +269,7 @@ class PipelineInterface:
|
|
| 262 |
self.config_output.blur(
|
| 263 |
fn=self.sm.update_workflow_from_code,
|
| 264 |
inputs=[self.config_output, self.pipeline_change],
|
| 265 |
-
outputs=[self.pipeline_state, self.pipeline_change],
|
| 266 |
)
|
| 267 |
|
| 268 |
# Connect the export button to show the workflow JSON
|
|
@@ -326,6 +333,6 @@ class PipelineInterface:
|
|
| 326 |
).success(
|
| 327 |
fn=self.sm.get_formatted_config,
|
| 328 |
inputs=[self.pipeline_state, gr.State("yaml")],
|
| 329 |
-
outputs=[self.config_output],
|
| 330 |
js=js,
|
| 331 |
)
|
|
|
|
| 10 |
PipelineState,
|
| 11 |
PipelineStateManager,
|
| 12 |
PipelineUIState,
|
| 13 |
+
PipelineValidator,
|
| 14 |
TossupPipelineState,
|
| 15 |
TossupPipelineStateManager,
|
| 16 |
)
|
| 17 |
from components.model_step.model_step import ModelStepComponent
|
| 18 |
from components.utils import make_state
|
| 19 |
from workflows.structs import ModelStep, TossupWorkflow, Workflow
|
| 20 |
+
from workflows.validators import WorkflowValidationError, WorkflowValidator
|
| 21 |
|
| 22 |
from .state_manager import get_output_panel_state
|
| 23 |
|
|
|
|
| 34 |
ui_state: PipelineUIState | None = None,
|
| 35 |
model_options: list[str] = None,
|
| 36 |
config: dict = {},
|
| 37 |
+
validator: PipelineValidator | None = None,
|
| 38 |
):
|
| 39 |
self.app = app
|
| 40 |
self.model_options = model_options
|
|
|
|
| 52 |
|
| 53 |
if isinstance(workflow, TossupWorkflow):
|
| 54 |
pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
|
| 55 |
+
self.sm = TossupPipelineStateManager(validator)
|
| 56 |
else:
|
| 57 |
pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
|
| 58 |
+
self.sm = PipelineStateManager(validator)
|
| 59 |
self.pipeline_state = make_state(pipeline_state.model_dump())
|
| 60 |
|
| 61 |
def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
|
|
|
|
| 171 |
"""Validate the workflow."""
|
| 172 |
try:
|
| 173 |
state = self.sm.make_pipeline_state(state_dict)
|
| 174 |
+
validator = WorkflowValidator(
|
| 175 |
+
max_temperature=self.config.get("max_temperature", 10),
|
| 176 |
+
)
|
| 177 |
+
if not validator.validate(state.workflow):
|
| 178 |
+
raise WorkflowValidationError(validator.errors)
|
| 179 |
except ValueError as e:
|
| 180 |
logger.exception(e)
|
| 181 |
state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
|
|
|
|
| 250 |
def _render_pipeline_preview(self):
|
| 251 |
export_btn = gr.Button("Export Pipeline", elem_classes="export-button", visible=False)
|
| 252 |
# components.append(export_btn)
|
| 253 |
+
self.error_display = gr.HTML(label="Error", elem_id="pipeline-preview-error-display", visible=False)
|
| 254 |
|
| 255 |
# Add a code box to display the workflow JSON
|
| 256 |
# with gr.Column(elem_classes="workflow-json-container"):
|
|
|
|
| 269 |
self.config_output.blur(
|
| 270 |
fn=self.sm.update_workflow_from_code,
|
| 271 |
inputs=[self.config_output, self.pipeline_change],
|
| 272 |
+
outputs=[self.pipeline_state, self.pipeline_change, self.error_display],
|
| 273 |
)
|
| 274 |
|
| 275 |
# Connect the export button to show the workflow JSON
|
|
|
|
| 333 |
).success(
|
| 334 |
fn=self.sm.get_formatted_config,
|
| 335 |
inputs=[self.pipeline_state, gr.State("yaml")],
|
| 336 |
+
outputs=[self.config_output, self.error_display],
|
| 337 |
js=js,
|
| 338 |
)
|
src/components/model_pipeline/state_manager.py
CHANGED
|
@@ -1,7 +1,12 @@
|
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
from typing import Literal
|
| 3 |
|
|
|
|
| 4 |
import yaml
|
|
|
|
|
|
|
| 5 |
|
| 6 |
from app_configs import UNSELECTED_VAR_NAME
|
| 7 |
from components import typed_dicts as td
|
|
@@ -22,24 +27,49 @@ def get_output_panel_state(workflow: Workflow) -> dict:
|
|
| 22 |
return state
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
class PipelineStateManager:
|
| 26 |
"""Manages a pipeline of multiple steps."""
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
|
| 29 |
"""Make a state from a state dictionary."""
|
| 30 |
-
return
|
| 31 |
-
|
| 32 |
-
def get_formatted_config(self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml") -> str:
|
| 33 |
-
"""Get the full pipeline configuration."""
|
| 34 |
-
state = self.make_pipeline_state(state_dict)
|
| 35 |
-
config = state.workflow.model_dump(exclude_defaults=True)
|
| 36 |
-
if isinstance(state.workflow, TossupWorkflow):
|
| 37 |
-
buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
|
| 38 |
-
config["buzzer"] = buzzer_config
|
| 39 |
-
if format == "yaml":
|
| 40 |
-
return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
|
| 41 |
-
else:
|
| 42 |
-
return json.dumps(config, indent=4, sort_keys=False)
|
| 43 |
|
| 44 |
def add_step(
|
| 45 |
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
|
|
@@ -102,7 +132,7 @@ class PipelineStateManager:
|
|
| 102 |
produced_variable = None
|
| 103 |
"""Update the output variables for a step."""
|
| 104 |
state = self.make_pipeline_state(state_dict)
|
| 105 |
-
state.
|
| 106 |
return state.model_dump()
|
| 107 |
|
| 108 |
def update_model_step_ui(
|
|
@@ -117,53 +147,101 @@ class PipelineStateManager:
|
|
| 117 |
"""Get all variables from all steps."""
|
| 118 |
return self.make_pipeline_state(state_dict)
|
| 119 |
|
| 120 |
-
def parse_yaml_workflow(self, yaml_str: str) -> Workflow:
|
| 121 |
"""Parse a YAML workflow."""
|
| 122 |
workflow = yaml.safe_load(yaml_str)
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
"""Update a workflow from a YAML string."""
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
class TossupPipelineStateManager(PipelineStateManager):
|
| 132 |
"""Manages a tossup pipeline state."""
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
return TossupPipelineState(**state_dict)
|
| 137 |
|
| 138 |
-
def
|
| 139 |
-
|
| 140 |
-
workflow = yaml.safe_load(yaml_str)
|
| 141 |
-
return TossupWorkflow(**workflow)
|
| 142 |
|
| 143 |
-
def update_workflow_from_code(
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
return
|
| 147 |
|
| 148 |
def update_model_step_state(
|
| 149 |
self, state_dict: td.TossupPipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
|
| 150 |
) -> td.TossupPipelineStateDict:
|
| 151 |
-
|
| 152 |
-
state = self.make_pipeline_state(state_dict)
|
| 153 |
-
state = state.update_step(model_step, ui_state)
|
| 154 |
-
state.workflow = state.workflow.refresh_buzzer()
|
| 155 |
-
return state.model_dump()
|
| 156 |
|
| 157 |
def update_output_variables(
|
| 158 |
self, state_dict: td.TossupPipelineStateDict, target: str, produced_variable: str
|
| 159 |
) -> td.TossupPipelineStateDict:
|
| 160 |
-
|
| 161 |
-
produced_variable = None
|
| 162 |
-
"""Update the output variables for a step."""
|
| 163 |
-
state = self.make_pipeline_state(state_dict)
|
| 164 |
-
state.workflow.outputs[target] = produced_variable
|
| 165 |
-
state.workflow = state.workflow.refresh_buzzer()
|
| 166 |
-
return state.model_dump()
|
| 167 |
|
| 168 |
def update_buzzer(
|
| 169 |
self,
|
|
|
|
| 1 |
+
# %%
|
| 2 |
import json
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
from typing import Literal
|
| 5 |
|
| 6 |
+
import gradio as gr
|
| 7 |
import yaml
|
| 8 |
+
from loguru import logger
|
| 9 |
+
from pydantic import BaseModel, ValidationError
|
| 10 |
|
| 11 |
from app_configs import UNSELECTED_VAR_NAME
|
| 12 |
from components import typed_dicts as td
|
|
|
|
| 27 |
return state
|
| 28 |
|
| 29 |
|
| 30 |
+
def strict_model_validate(model_cls: type[BaseModel], data: dict):
|
| 31 |
+
# Dynamically create a subclass with extra='forbid'
|
| 32 |
+
class_name = model_cls.__name__
|
| 33 |
+
strict_class_name = f"Strict{class_name}"
|
| 34 |
+
|
| 35 |
+
strict_class = type(
|
| 36 |
+
strict_class_name,
|
| 37 |
+
(model_cls,),
|
| 38 |
+
{"model_config": {**getattr(model_cls, "model_config", {}), "extra": "forbid"}},
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
return strict_class.model_validate(data)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class PipelineValidator(ABC):
|
| 45 |
+
"""Abstract base class for pipeline validators."""
|
| 46 |
+
|
| 47 |
+
@abstractmethod
|
| 48 |
+
def __call__(self, workflow: Workflow):
|
| 49 |
+
"""
|
| 50 |
+
Validate the workflow.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
workflow: The workflow to validate.
|
| 54 |
+
|
| 55 |
+
Raises:
|
| 56 |
+
ValueError: If the workflow is invalid.
|
| 57 |
+
"""
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
class PipelineStateManager:
|
| 62 |
"""Manages a pipeline of multiple steps."""
|
| 63 |
|
| 64 |
+
pipeline_state_cls = PipelineState
|
| 65 |
+
workflow_cls = Workflow
|
| 66 |
+
|
| 67 |
+
def __init__(self, validator: PipelineValidator | None = None):
|
| 68 |
+
self.validator = validator
|
| 69 |
+
|
| 70 |
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
|
| 71 |
"""Make a state from a state dictionary."""
|
| 72 |
+
return self.pipeline_state_cls(**state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def add_step(
|
| 75 |
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
|
|
|
|
| 132 |
produced_variable = None
|
| 133 |
"""Update the output variables for a step."""
|
| 134 |
state = self.make_pipeline_state(state_dict)
|
| 135 |
+
state = state.update_output_variable(target, produced_variable)
|
| 136 |
return state.model_dump()
|
| 137 |
|
| 138 |
def update_model_step_ui(
|
|
|
|
| 147 |
"""Get all variables from all steps."""
|
| 148 |
return self.make_pipeline_state(state_dict)
|
| 149 |
|
| 150 |
+
def parse_yaml_workflow(self, yaml_str: str, strict: bool = True) -> Workflow:
|
| 151 |
"""Parse a YAML workflow."""
|
| 152 |
workflow = yaml.safe_load(yaml_str)
|
| 153 |
+
try:
|
| 154 |
+
if strict:
|
| 155 |
+
return strict_model_validate(self.workflow_cls, workflow)
|
| 156 |
+
else:
|
| 157 |
+
return self.workflow_cls.model_validate(workflow)
|
| 158 |
+
except ValidationError as e:
|
| 159 |
+
new_exception = ValidationError.from_exception_data(
|
| 160 |
+
e.title.removeprefix("Strict"), e.errors(), input_type="json"
|
| 161 |
+
)
|
| 162 |
+
raise new_exception from e
|
| 163 |
+
|
| 164 |
+
def _handle_pipeline_parsing_error(self, e: Exception) -> str:
|
| 165 |
+
"""Format error messages for pipeline parsing errors with consistent styling."""
|
| 166 |
+
error_template = """
|
| 167 |
+
<div class="md" style='color: #FF0000; background-color: #FFEEEE; padding: 10px; border-radius: 5px; border-left: 4px solid #FF0000;'>
|
| 168 |
+
<strong style='color: #FF0000;'>{error_type}:</strong> <br>
|
| 169 |
+
<div class="code-wrap">
|
| 170 |
+
<pre><code>{error_message}</code></pre>
|
| 171 |
+
</div>
|
| 172 |
+
{help_text}
|
| 173 |
+
</div>
|
| 174 |
+
"""
|
| 175 |
+
logger.exception(e)
|
| 176 |
+
if isinstance(e, yaml.YAMLError):
|
| 177 |
+
error_type = "Invalid YAML Error"
|
| 178 |
+
help_text = "Refer to the <a href='https://spacelift.io/blog/yaml#basic-yaml-syntax' target='_blank'>YAML schema</a> for correct formatting."
|
| 179 |
+
elif isinstance(e, ValidationError):
|
| 180 |
+
error_type = "Pipeline Parsing Error"
|
| 181 |
+
help_text = "Refer to the <a href='https://mgor.info' target='_blank'>documentation</a> for the correct pipeline schema."
|
| 182 |
+
elif isinstance(e, ValueError):
|
| 183 |
+
error_type = "Pipeline Validation Error"
|
| 184 |
+
help_text = "Refer to the <a href='https://mgor.info' target='_blank'>documentation</a> for the correct pipeline schema."
|
| 185 |
+
else:
|
| 186 |
+
error_type = "Unexpected Error"
|
| 187 |
+
help_text = "Please report this issue to us at <a href='https://github.com/maharshi95/QANTA25/issues' target='_blank'>GitHub Issues</a>."
|
| 188 |
+
|
| 189 |
+
return error_template.format(error_type=error_type, error_message=str(e), help_text=help_text)
|
| 190 |
|
| 191 |
+
def get_formatted_config(
|
| 192 |
+
self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml"
|
| 193 |
+
) -> tuple[str, dict]:
|
| 194 |
+
"""Get the full pipeline configuration."""
|
| 195 |
+
try:
|
| 196 |
+
state = self.make_pipeline_state(state_dict)
|
| 197 |
+
config = state.workflow.model_dump(exclude_defaults=True)
|
| 198 |
+
if isinstance(state.workflow, TossupWorkflow):
|
| 199 |
+
buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
|
| 200 |
+
config["buzzer"] = buzzer_config
|
| 201 |
+
if format == "yaml":
|
| 202 |
+
config_str = yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
|
| 203 |
+
else:
|
| 204 |
+
config_str = json.dumps(config, indent=4, sort_keys=False)
|
| 205 |
+
return config_str, gr.update(visible=False)
|
| 206 |
+
except Exception as e:
|
| 207 |
+
error_message = self._handle_pipeline_parsing_error(e)
|
| 208 |
+
return gr.skip(), gr.update(value=error_message, visible=True)
|
| 209 |
+
|
| 210 |
+
def update_workflow_from_code(self, yaml_str: str, change_state: bool) -> tuple[td.PipelineStateDict, bool, dict]:
|
| 211 |
"""Update a workflow from a YAML string."""
|
| 212 |
+
try:
|
| 213 |
+
workflow = self.parse_yaml_workflow(yaml_str, strict=True)
|
| 214 |
+
self.validator and self.validator(workflow)
|
| 215 |
+
state = self.pipeline_state_cls.from_workflow(workflow)
|
| 216 |
+
return state.model_dump(), not change_state, gr.update(visible=False)
|
| 217 |
+
except Exception as e:
|
| 218 |
+
error_message = self._handle_pipeline_parsing_error(e)
|
| 219 |
+
return gr.skip(), gr.skip(), gr.update(value=error_message, visible=True)
|
| 220 |
|
| 221 |
|
| 222 |
class TossupPipelineStateManager(PipelineStateManager):
|
| 223 |
"""Manages a tossup pipeline state."""
|
| 224 |
|
| 225 |
+
pipeline_state_cls = TossupPipelineState
|
| 226 |
+
workflow_cls = TossupWorkflow
|
|
|
|
| 227 |
|
| 228 |
+
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
|
| 229 |
+
return super().make_pipeline_state(state_dict)
|
|
|
|
|
|
|
| 230 |
|
| 231 |
+
def update_workflow_from_code(
|
| 232 |
+
self, yaml_str: str, change_state: bool
|
| 233 |
+
) -> tuple[td.TossupPipelineStateDict, bool, dict]:
|
| 234 |
+
return super().update_workflow_from_code(yaml_str, change_state)
|
| 235 |
|
| 236 |
def update_model_step_state(
|
| 237 |
self, state_dict: td.TossupPipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
|
| 238 |
) -> td.TossupPipelineStateDict:
|
| 239 |
+
return super().update_model_step_state(state_dict, model_step, ui_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
def update_output_variables(
|
| 242 |
self, state_dict: td.TossupPipelineStateDict, target: str, produced_variable: str
|
| 243 |
) -> td.TossupPipelineStateDict:
|
| 244 |
+
return super().update_output_variables(state_dict, target, produced_variable)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
def update_buzzer(
|
| 247 |
self,
|
src/components/model_pipeline/tossup_pipeline.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import numpy as np
|
| 3 |
from loguru import logger
|
| 4 |
|
| 5 |
from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
|
|
@@ -9,7 +8,8 @@ from components.typed_dicts import TossupPipelineStateDict
|
|
| 9 |
from display.formatting import tiny_styled_warning
|
| 10 |
from workflows.structs import Buzzer, TossupWorkflow
|
| 11 |
|
| 12 |
-
from .model_pipeline import PipelineInterface
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
def toggleable_slider(
|
|
@@ -40,8 +40,9 @@ class TossupPipelineInterface(PipelineInterface):
|
|
| 40 |
ui_state: PipelineUIState | None = None,
|
| 41 |
model_options: list[str] = None,
|
| 42 |
config: dict = {},
|
|
|
|
| 43 |
):
|
| 44 |
-
super().__init__(app, workflow, ui_state, model_options, config)
|
| 45 |
|
| 46 |
self.buzzer_state = gr.State(workflow.buzzer.model_dump())
|
| 47 |
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
from loguru import logger
|
| 3 |
|
| 4 |
from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
|
|
|
|
| 8 |
from display.formatting import tiny_styled_warning
|
| 9 |
from workflows.structs import Buzzer, TossupWorkflow
|
| 10 |
|
| 11 |
+
from .model_pipeline import PipelineInterface
|
| 12 |
+
from .state_manager import PipelineUIState, PipelineValidator
|
| 13 |
|
| 14 |
|
| 15 |
def toggleable_slider(
|
|
|
|
| 40 |
ui_state: PipelineUIState | None = None,
|
| 41 |
model_options: list[str] = None,
|
| 42 |
config: dict = {},
|
| 43 |
+
validator: PipelineValidator | None = None,
|
| 44 |
):
|
| 45 |
+
super().__init__(app, workflow, ui_state, model_options, config, validator)
|
| 46 |
|
| 47 |
self.buzzer_state = gr.State(workflow.buzzer.model_dump())
|
| 48 |
|
src/components/quizbowl/bonus.py
CHANGED
|
@@ -19,6 +19,7 @@ from workflows.qb_agents import QuizBowlBonusAgent
|
|
| 19 |
from . import populate, validation
|
| 20 |
from .plotting import create_bonus_confidence_plot, create_bonus_html
|
| 21 |
from .utils import evaluate_prediction
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
def process_bonus_results(results: list[dict]) -> pd.DataFrame:
|
|
@@ -105,6 +106,7 @@ class BonusInterface:
|
|
| 105 |
ui_state=pipeline_state.ui_state,
|
| 106 |
model_options=list(self.model_options.keys()),
|
| 107 |
config=self.defaults,
|
|
|
|
| 108 |
)
|
| 109 |
|
| 110 |
def _render_qb_interface(self):
|
|
|
|
| 19 |
from . import populate, validation
|
| 20 |
from .plotting import create_bonus_confidence_plot, create_bonus_html
|
| 21 |
from .utils import evaluate_prediction
|
| 22 |
+
from .validation import UserInputWorkflowValidator
|
| 23 |
|
| 24 |
|
| 25 |
def process_bonus_results(results: list[dict]) -> pd.DataFrame:
|
|
|
|
| 106 |
ui_state=pipeline_state.ui_state,
|
| 107 |
model_options=list(self.model_options.keys()),
|
| 108 |
config=self.defaults,
|
| 109 |
+
validator=UserInputWorkflowValidator("bonus"),
|
| 110 |
)
|
| 111 |
|
| 112 |
def _render_qb_interface(self):
|
src/components/quizbowl/tossup.py
CHANGED
|
@@ -25,6 +25,7 @@ from .plotting import (
|
|
| 25 |
prepare_tossup_results_df,
|
| 26 |
)
|
| 27 |
from .utils import evaluate_prediction
|
|
|
|
| 28 |
|
| 29 |
# TODO: Error handling on run tossup and evaluate tossup and show correct messages
|
| 30 |
# TODO: ^^ Same for Bonus
|
|
@@ -135,7 +136,7 @@ class TossupInterface:
|
|
| 135 |
self.output_state = gr.State(value={})
|
| 136 |
self.render()
|
| 137 |
|
| 138 |
-
# ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE
|
| 139 |
|
| 140 |
def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
|
| 141 |
logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
|
|
@@ -165,6 +166,7 @@ class TossupInterface:
|
|
| 165 |
ui_state=pipeline_state.ui_state,
|
| 166 |
model_options=list(self.model_options.keys()),
|
| 167 |
config=self.defaults,
|
|
|
|
| 168 |
)
|
| 169 |
|
| 170 |
def _render_qb_interface(self):
|
|
|
|
| 25 |
prepare_tossup_results_df,
|
| 26 |
)
|
| 27 |
from .utils import evaluate_prediction
|
| 28 |
+
from .validation import UserInputWorkflowValidator
|
| 29 |
|
| 30 |
# TODO: Error handling on run tossup and evaluate tossup and show correct messages
|
| 31 |
# TODO: ^^ Same for Bonus
|
|
|
|
| 136 |
self.output_state = gr.State(value={})
|
| 137 |
self.render()
|
| 138 |
|
| 139 |
+
# ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE ------------------------------------
|
| 140 |
|
| 141 |
def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
|
| 142 |
logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
|
|
|
|
| 166 |
ui_state=pipeline_state.ui_state,
|
| 167 |
model_options=list(self.model_options.keys()),
|
| 168 |
config=self.defaults,
|
| 169 |
+
validator=UserInputWorkflowValidator("tossup"),
|
| 170 |
)
|
| 171 |
|
| 172 |
def _render_qb_interface(self):
|
src/components/quizbowl/validation.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
from app_configs import CONFIGS
|
| 2 |
from components.structs import PipelineState, TossupPipelineState
|
| 3 |
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
|
|
@@ -53,3 +55,31 @@ def validate_bonus_workflow(pipeline_state_dict: PipelineStateDict):
|
|
| 53 |
CONFIGS["bonus"]["required_output_vars"],
|
| 54 |
)
|
| 55 |
return pipeline_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
|
| 3 |
from app_configs import CONFIGS
|
| 4 |
from components.structs import PipelineState, TossupPipelineState
|
| 5 |
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
|
|
|
|
| 55 |
CONFIGS["bonus"]["required_output_vars"],
|
| 56 |
)
|
| 57 |
return pipeline_state
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class UserInputWorkflowValidator:
|
| 61 |
+
def __init__(self, mode: Literal["tossup", "bonus"]):
|
| 62 |
+
self.required_input_vars = CONFIGS[mode]["required_input_vars"]
|
| 63 |
+
self.required_output_vars = CONFIGS[mode]["required_output_vars"]
|
| 64 |
+
|
| 65 |
+
def __call__(self, workflow: TossupWorkflow):
|
| 66 |
+
input_vars = set(workflow.inputs)
|
| 67 |
+
for req_var in self.required_input_vars:
|
| 68 |
+
if req_var not in input_vars:
|
| 69 |
+
default_str = "inputs:\n" + "\n".join([f"- {var}" for var in self.required_input_vars])
|
| 70 |
+
raise ValueError(
|
| 71 |
+
f"Missing required input variable: '{req_var}'. "
|
| 72 |
+
"\nDon't modify the 'inputs' field in the workflow. "
|
| 73 |
+
"Please set it back to:"
|
| 74 |
+
f"\n{default_str}"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
output_vars = set(workflow.outputs)
|
| 78 |
+
for req_var in self.required_output_vars:
|
| 79 |
+
if req_var not in output_vars:
|
| 80 |
+
default_str = "[" + ", ".join([f"'{var}'" for var in self.required_output_vars]) + "]"
|
| 81 |
+
raise ValueError(
|
| 82 |
+
f"Missing required output variable: '{req_var}'. "
|
| 83 |
+
"\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
|
| 84 |
+
f"\nMake sure you have values set for all the outputs: {default_str}"
|
| 85 |
+
)
|
src/components/structs.py
CHANGED
|
@@ -143,6 +143,11 @@ class PipelineState(BaseModel):
|
|
| 143 |
update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
|
| 144 |
return self.model_copy(update=update)
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
|
| 147 |
"""Get all variables from all steps."""
|
| 148 |
available_variables = self.available_variables
|
|
@@ -170,3 +175,14 @@ class PipelineState(BaseModel):
|
|
| 170 |
|
| 171 |
class TossupPipelineState(PipelineState):
|
| 172 |
workflow: TossupWorkflow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
|
| 144 |
return self.model_copy(update=update)
|
| 145 |
|
| 146 |
+
def update_output_variable(self, target: str, produced_variable: str) -> "PipelineState":
|
| 147 |
+
"""Update the output variables for a step."""
|
| 148 |
+
self.workflow.outputs[target] = produced_variable
|
| 149 |
+
return self
|
| 150 |
+
|
| 151 |
def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
|
| 152 |
"""Get all variables from all steps."""
|
| 153 |
available_variables = self.available_variables
|
|
|
|
| 175 |
|
| 176 |
class TossupPipelineState(PipelineState):
|
| 177 |
workflow: TossupWorkflow
|
| 178 |
+
|
| 179 |
+
def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "TossupPipelineState":
|
| 180 |
+
"""Update a step in the pipeline."""
|
| 181 |
+
state = super().update_step(step, ui_state)
|
| 182 |
+
state.workflow = state.workflow.refresh_buzzer()
|
| 183 |
+
return state
|
| 184 |
+
|
| 185 |
+
def update_output_variable(self, target: str, produced_variable: str) -> "TossupPipelineState":
|
| 186 |
+
state = super().update_output_variable(target, produced_variable)
|
| 187 |
+
state.workflow = state.workflow.refresh_buzzer()
|
| 188 |
+
return state
|
src/workflows/validators.py
CHANGED
|
@@ -13,7 +13,6 @@ SUPPORTED_TYPES = {"str", "int", "float", "bool", "list[str]", "list[int]", "lis
|
|
| 13 |
MAX_FIELD_NAME_LENGTH = 50
|
| 14 |
MAX_DESCRIPTION_LENGTH = 200
|
| 15 |
MAX_SYSTEM_PROMPT_LENGTH = 4000
|
| 16 |
-
MIN_TEMPERATURE = 0.0
|
| 17 |
MAX_TEMPERATURE = 10.0
|
| 18 |
|
| 19 |
|
|
@@ -40,7 +39,7 @@ class ValidationError:
|
|
| 40 |
field_name: Optional[str] = None
|
| 41 |
|
| 42 |
|
| 43 |
-
class WorkflowValidationError(
|
| 44 |
"""Base class for workflow validation errors"""
|
| 45 |
|
| 46 |
def __init__(self, errors: list[ValidationError]):
|
|
@@ -77,9 +76,18 @@ def create_step_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
|
|
| 77 |
class WorkflowValidator:
|
| 78 |
"""Validates workflows for correctness and consistency"""
|
| 79 |
|
| 80 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
self.errors: list[ValidationError] = []
|
| 82 |
self.workflow: Optional[Workflow] = None
|
|
|
|
|
|
|
| 83 |
|
| 84 |
def validate(self, workflow: Workflow) -> bool:
|
| 85 |
"""Main validation entry point"""
|
|
@@ -272,7 +280,7 @@ class WorkflowValidator:
|
|
| 272 |
self.errors.append(
|
| 273 |
ValidationError(
|
| 274 |
ValidationErrorType.NAMING,
|
| 275 |
-
f"Invalid step ID format: {step.id}. Must be a valid
|
| 276 |
step.id,
|
| 277 |
)
|
| 278 |
)
|
|
@@ -286,11 +294,11 @@ class WorkflowValidator:
|
|
| 286 |
)
|
| 287 |
return False
|
| 288 |
|
| 289 |
-
if not
|
| 290 |
self.errors.append(
|
| 291 |
ValidationError(
|
| 292 |
ValidationErrorType.RANGE,
|
| 293 |
-
f"Temperature must be between {
|
| 294 |
step.id,
|
| 295 |
)
|
| 296 |
)
|
|
@@ -304,11 +312,11 @@ class WorkflowValidator:
|
|
| 304 |
)
|
| 305 |
return False
|
| 306 |
|
| 307 |
-
if len(step.system_prompt) >
|
| 308 |
self.errors.append(
|
| 309 |
ValidationError(
|
| 310 |
ValidationErrorType.LENGTH,
|
| 311 |
-
f"System prompt exceeds maximum length of {
|
| 312 |
step.id,
|
| 313 |
)
|
| 314 |
)
|
|
@@ -365,22 +373,22 @@ class WorkflowValidator:
|
|
| 365 |
return False
|
| 366 |
|
| 367 |
# Validate field name length
|
| 368 |
-
if len(field.name) >
|
| 369 |
self.errors.append(
|
| 370 |
ValidationError(
|
| 371 |
ValidationErrorType.LENGTH,
|
| 372 |
-
f"Field name exceeds maximum length of {
|
| 373 |
field_name=field.name,
|
| 374 |
)
|
| 375 |
)
|
| 376 |
return False
|
| 377 |
|
| 378 |
# Validate description length
|
| 379 |
-
if len(field.description) >
|
| 380 |
self.errors.append(
|
| 381 |
ValidationError(
|
| 382 |
ValidationErrorType.LENGTH,
|
| 383 |
-
f"Description exceeds maximum length of {
|
| 384 |
field_name=field.name,
|
| 385 |
)
|
| 386 |
)
|
|
@@ -422,22 +430,22 @@ class WorkflowValidator:
|
|
| 422 |
return False
|
| 423 |
|
| 424 |
# Validate field name length
|
| 425 |
-
if len(field.name) >
|
| 426 |
self.errors.append(
|
| 427 |
ValidationError(
|
| 428 |
ValidationErrorType.LENGTH,
|
| 429 |
-
f"Field name exceeds maximum length of {
|
| 430 |
field_name=field.name,
|
| 431 |
)
|
| 432 |
)
|
| 433 |
return False
|
| 434 |
|
| 435 |
# Validate description length
|
| 436 |
-
if len(field.description) >
|
| 437 |
self.errors.append(
|
| 438 |
ValidationError(
|
| 439 |
ValidationErrorType.LENGTH,
|
| 440 |
-
f"Description exceeds maximum length of {
|
| 441 |
field_name=field.name,
|
| 442 |
)
|
| 443 |
)
|
|
@@ -545,10 +553,6 @@ class WorkflowValidator:
|
|
| 545 |
|
| 546 |
def _is_valid_identifier(self, name: str) -> bool:
|
| 547 |
"""Validates if a string is a valid Python identifier"""
|
| 548 |
-
if
|
| 549 |
-
return
|
| 550 |
-
|
| 551 |
-
return False
|
| 552 |
-
if not name.strip(): # Check for whitespace-only strings
|
| 553 |
-
return False
|
| 554 |
-
return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
|
|
|
|
| 13 |
MAX_FIELD_NAME_LENGTH = 50
|
| 14 |
MAX_DESCRIPTION_LENGTH = 200
|
| 15 |
MAX_SYSTEM_PROMPT_LENGTH = 4000
|
|
|
|
| 16 |
MAX_TEMPERATURE = 10.0
|
| 17 |
|
| 18 |
|
|
|
|
| 39 |
field_name: Optional[str] = None
|
| 40 |
|
| 41 |
|
| 42 |
+
class WorkflowValidationError(ValueError):
|
| 43 |
"""Base class for workflow validation errors"""
|
| 44 |
|
| 45 |
def __init__(self, errors: list[ValidationError]):
|
|
|
|
| 76 |
class WorkflowValidator:
|
| 77 |
"""Validates workflows for correctness and consistency"""
|
| 78 |
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
min_temperature: float = 0,
|
| 82 |
+
max_temperature: float = MAX_TEMPERATURE,
|
| 83 |
+
max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
|
| 84 |
+
max_description_length: int = MAX_DESCRIPTION_LENGTH,
|
| 85 |
+
max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
|
| 86 |
+
):
|
| 87 |
self.errors: list[ValidationError] = []
|
| 88 |
self.workflow: Optional[Workflow] = None
|
| 89 |
+
self.min_temperature = min_temperature
|
| 90 |
+
self.max_temperature = max_temperature
|
| 91 |
|
| 92 |
def validate(self, workflow: Workflow) -> bool:
|
| 93 |
"""Main validation entry point"""
|
|
|
|
| 280 |
self.errors.append(
|
| 281 |
ValidationError(
|
| 282 |
ValidationErrorType.NAMING,
|
| 283 |
+
f"Invalid step ID format: {step.id}. Must be a valid identifier.",
|
| 284 |
step.id,
|
| 285 |
)
|
| 286 |
)
|
|
|
|
| 294 |
)
|
| 295 |
return False
|
| 296 |
|
| 297 |
+
if not self.min_temperature <= step.temperature <= self.max_temperature:
|
| 298 |
self.errors.append(
|
| 299 |
ValidationError(
|
| 300 |
ValidationErrorType.RANGE,
|
| 301 |
+
f"Temperature must be between {self.min_temperature} and {self.max_temperature}",
|
| 302 |
step.id,
|
| 303 |
)
|
| 304 |
)
|
|
|
|
| 312 |
)
|
| 313 |
return False
|
| 314 |
|
| 315 |
+
if len(step.system_prompt) > self.max_system_prompt_length:
|
| 316 |
self.errors.append(
|
| 317 |
ValidationError(
|
| 318 |
ValidationErrorType.LENGTH,
|
| 319 |
+
f"System prompt exceeds maximum length of {self.max_system_prompt_length} characters",
|
| 320 |
step.id,
|
| 321 |
)
|
| 322 |
)
|
|
|
|
| 373 |
return False
|
| 374 |
|
| 375 |
# Validate field name length
|
| 376 |
+
if len(field.name) > self.max_field_name_length:
|
| 377 |
self.errors.append(
|
| 378 |
ValidationError(
|
| 379 |
ValidationErrorType.LENGTH,
|
| 380 |
+
f"Field name exceeds maximum length of {self.max_field_name_length} characters",
|
| 381 |
field_name=field.name,
|
| 382 |
)
|
| 383 |
)
|
| 384 |
return False
|
| 385 |
|
| 386 |
# Validate description length
|
| 387 |
+
if len(field.description) > self.max_description_length:
|
| 388 |
self.errors.append(
|
| 389 |
ValidationError(
|
| 390 |
ValidationErrorType.LENGTH,
|
| 391 |
+
f"Description exceeds maximum length of {self.max_description_length} characters",
|
| 392 |
field_name=field.name,
|
| 393 |
)
|
| 394 |
)
|
|
|
|
| 430 |
return False
|
| 431 |
|
| 432 |
# Validate field name length
|
| 433 |
+
if len(field.name) > self.max_field_name_length:
|
| 434 |
self.errors.append(
|
| 435 |
ValidationError(
|
| 436 |
ValidationErrorType.LENGTH,
|
| 437 |
+
f"Field name exceeds maximum length of {self.max_field_name_length} characters",
|
| 438 |
field_name=field.name,
|
| 439 |
)
|
| 440 |
)
|
| 441 |
return False
|
| 442 |
|
| 443 |
# Validate description length
|
| 444 |
+
if len(field.description) > self.max_description_length:
|
| 445 |
self.errors.append(
|
| 446 |
ValidationError(
|
| 447 |
ValidationErrorType.LENGTH,
|
| 448 |
+
f"Description exceeds maximum length of {self.max_description_length} characters",
|
| 449 |
field_name=field.name,
|
| 450 |
)
|
| 451 |
)
|
|
|
|
| 553 |
|
| 554 |
def _is_valid_identifier(self, name: str) -> bool:
|
| 555 |
"""Validates if a string is a valid Python identifier"""
|
| 556 |
+
if name and name.strip():
|
| 557 |
+
return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
|
| 558 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|