| from typing import List, Literal | |
| import numpy as np | |
def generate_parameters_with_timesteps(
    start: int,
    num: int,
    stop: int = None,
    method: Literal["linear", "two_stage", "three_stage", "fix_two_stage"] = "linear",
    n_fix_start: int = 3,
) -> List[float]:
    """Build a per-timestep schedule of ``num`` values from ``start`` to ``stop``.

    Args:
        start: Value used at the beginning of the schedule.
        num: Number of values to produce.
        stop: Value used at the end of the schedule. When ``None`` or equal to
            ``start`` the schedule is constant (``[start] * num``) and
            ``method`` is not consulted.
        method: Shape of the schedule:
            ``"linear"`` evenly spaced values (``np.linspace``);
            ``"two_stage"`` first half ``start``, second half ``stop``;
            ``"three_stage"`` thirds of ``start``, midpoint, ``stop``;
            ``"fix_two_stage"`` ``n_fix_start`` copies of ``start``, then ``stop``.
        n_fix_start: Number of leading ``start`` values for ``"fix_two_stage"``;
            ignored by the other methods.

    Returns:
        The list of per-timestep values.

    Raises:
        ValueError: If ``method`` is not one of the supported names and a
            non-constant schedule was requested.
    """
    # Nothing to interpolate: every timestep gets the same value.
    if stop is None or start == stop:
        return [start] * num

    if method == "linear":
        return generate_linear_parameters(start, stop, num)
    if method == "two_stage":
        return generate_two_stages_parameters(start, stop, num)
    if method == "three_stage":
        return generate_three_stages_parameters(start, stop, num)
    if method == "fix_two_stage":
        return generate_fix_two_stages_parameters(start, stop, num, n_fix_start)
    raise ValueError(
        "now only support linear, two_stage, three_stage, fix_two_stage, "
        f"but given {method}"
    )
def generate_linear_parameters(start, stop, num):
    """Return ``num`` evenly spaced values from ``start`` to ``stop``.

    The values come from ``np.linspace`` (both endpoints included by its
    default) and are collected element by element into a plain list, so the
    elements are the NumPy scalars the array yields on iteration.
    """
    samples = np.linspace(start=start, stop=stop, num=num)
    return [value for value in samples]
def generate_two_stages_parameters(start, stop, num):
    """Split ``num`` slots into a ``start`` half followed by a ``stop`` half.

    The first half has ``num // 2`` entries; when ``num`` is odd the extra
    slot goes to the ``stop`` half.
    """
    head_len = num // 2
    tail_len = num - head_len
    head = [start] * head_len
    tail = [stop] * tail_len
    return head + tail
def generate_fix_two_stages_parameters(start, stop, num, n_fix_start: int) -> List:
    """Return ``n_fix_start`` copies of ``start`` followed by ``stop`` values.

    Args:
        start: Value for the leading fixed segment.
        stop: Value for every remaining slot.
        num: Total number of values to produce.
        n_fix_start: Requested length of the leading ``start`` segment.

    Returns:
        A list of exactly ``max(num, 0)`` values. ``n_fix_start`` is clamped
        to ``[0, num]``. Without the clamp, an oversized value would return
        more than ``num`` entries (``[stop] * negative`` is empty), and a
        negative value would pad the ``stop`` segment past ``num``.
    """
    num_start = min(max(n_fix_start, 0), max(num, 0))
    num_end = num - num_start
    return [start] * num_start + [stop] * num_end
def generate_three_stages_parameters(start, stop, num):
    """Return a three-step schedule: ``start``, then a midpoint, then ``stop``.

    The first two segments each get ``num // 3`` entries; the ``stop``
    segment takes the remainder, so it is the longest when ``num`` is not a
    multiple of 3.

    For integer endpoints the midpoint is ``(start + stop) // 2`` (floor),
    as before. For any other numeric endpoints it is the true mean
    ``(start + stop) / 2``.
    """
    import numbers

    if isinstance(start, numbers.Integral) and isinstance(stop, numbers.Integral):
        # Keep the historical floor midpoint so integer schedules stay integers.
        middle = (start + stop) // 2
    else:
        # Floor division on floats can push the "middle" outside
        # [start, stop] (0.5 and 1.0 would give 0.0), so use the real mean.
        middle = (start + stop) / 2
    num_start = num // 3
    num_middle = num_start
    num_end = num - num_start - num_middle
    return [start] * num_start + [middle] * num_middle + [stop] * num_end