Spaces:
Running
Running
| # Copyright 2024 the LlamaFactory team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
import json
import math
import os
from typing import Any, Dict, List, Optional

from transformers.trainer import TRAINER_STATE_NAME

from .logging import get_logger
from .packages import is_matplotlib_available
# Matplotlib is an optional dependency: import it only when available so the
# non-plotting parts of this module (e.g. `smooth`) still work without it.
if is_matplotlib_available():
    import matplotlib.figure
    import matplotlib.pyplot as plt


# Module-level logger named after this module.
logger = get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]:
    r"""
    Exponential moving average of ``scalars``, mirroring TensorBoard's smoothing.

    The smoothing weight is derived from the sequence length through a sigmoid,
    so longer series are smoothed more aggressively (weight approaches 0.9).
    Returns a list the same length as the input; an empty input yields ``[]``.
    """
    if not scalars:
        return []

    # Sigmoid of the length, rescaled into the open interval (0, 0.9).
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function
    smoothed: List[float] = []
    previous = scalars[0]
    for value in scalars:
        ema = previous * weight + (1 - weight) * value
        smoothed.append(ema)
        previous = ema

    return smoothed
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
    r"""
    Plots loss curves in LlamaBoard.

    Args:
        trainer_log: list of log entries; entries carrying a "loss" value must
            also carry "current_steps" (its x-coordinate).

    Returns:
        A matplotlib Figure with the raw loss and its EMA-smoothed curve.
    """
    plt.close("all")
    plt.switch_backend("agg")  # headless backend: no display required
    fig = plt.figure()
    ax = fig.add_subplot(111)
    steps, losses = [], []
    for log in trainer_log:
        # Explicit None check so a loss of exactly 0.0 is still plotted;
        # the previous truthiness test silently dropped zero losses.
        if log.get("loss") is not None:
            steps.append(log["current_steps"])
            losses.append(log["loss"])

    ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
    ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
    ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    return fig
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = None) -> None:
    r"""
    Plots loss curves from the saved trainer state and writes one PNG per key.

    Args:
        save_dictionary: directory containing the trainer state JSON file;
            the figures are written into the same directory.
        keys: metric names from ``log_history`` to plot. Defaults to ``["loss"]``.
    """
    # Resolve the default here instead of using a mutable default argument,
    # which would be shared across calls.
    if keys is None:
        keys = ["loss"]

    plt.switch_backend("agg")  # headless backend: no display required
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for entry in data["log_history"]:
            if key in entry:
                steps.append(entry["step"])
                metrics.append(entry[key])

        if len(metrics) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        plt.figure()
        plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
        plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
        plt.title("training {} of {}".format(key, save_dictionary))
        plt.xlabel("step")
        plt.ylabel(key)
        plt.legend()
        figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
        plt.savefig(figure_path, format="png", dpi=100)
        plt.close()  # release the figure so repeated calls don't accumulate memory
        print("Figure saved at:", figure_path)