import json
from pathlib import Path
from os import PathLike
from typing import List, Literal
from itertools import repeat
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import seaborn as sns
import pandas as pd
from pytorch_fob.engine.parser import YAMLParser
from pytorch_fob.engine.utils import AttributeDict, convert_type_inside_dict, log_warn, log_info, log_debug
from pytorch_fob.evaluation import evaluation_path


def get_available_trials(dirname: Path, config: AttributeDict, depth: int = 1):
    """Find the directories of all trials below *dirname*.

    Descends exactly *depth* directory levels and keeps every directory on
    that level which directly contains a file named
    ``config.experiment_files.config``.

    Args:
        dirname: root directory to search in.
        config: evaluation config; only ``experiment_files.config`` is read.
        depth: number of directory levels to descend (0 checks *dirname* itself).

    Returns:
        The trial directories, in reversed discovery order.

    Raises:
        TypeError: if *dirname* is not a ``Path``.
    """
    if not isinstance(dirname, Path):
        raise TypeError(f"dirname must be a pathlib.Path, got {type(dirname).__name__}")

    # all trials are expected on the same level, so each step replaces the
    # current level by its child directories
    subdirs: list[Path] = [dirname]
    for _ in range(depth):
        subdirs = [child for subdir in subdirs for child in subdir.iterdir() if child.is_dir()]
    format_str = "\n "  # f-string expression part cannot include a backslash
    log_debug(f"found the following directories:{format_str}{format_str.join(str(i) for i in subdirs)}.")

    def is_trial(path: Path) -> bool:
        # here we could do additional checks to filter the subdirectories
        # currently we only check if there is a config file
        return any(entry.name == config.experiment_files.config for entry in path.iterdir())

    subdirs = [path for path in reversed(subdirs) if is_trial(path)]
    log_debug(f"We assume the following to be trials:{format_str}{format_str.join(str(i) for i in subdirs)}.")
    return subdirs


def dataframe_from_trials(trial_dir_paths: List[Path], config: AttributeDict) -> pd.DataFrame:
    """Pack the results of the given trials into one dataframe.

    Takes the result of ``get_available_trials``; every trial contributes one
    row made of its flattened YAML config plus the value of the metric
    ``config.plot.metric`` read from the trial's result json. Duplicate
    hyperparameter settings are not filtered. Trials with a missing config or
    result file are skipped with a warning.

    Raises:
        ValueError: if ``config.plot.metric`` is not set, or if no trial
            could be loaded.
    """
    # the metric has to be given by the user; checked once up front so that we
    # fail before touching any file
    metric_of_value_to_plot = config.plot.metric
    if not metric_of_value_to_plot:
        raise ValueError("evaluation.plot.metric is not set")

    dfs: List[pd.DataFrame] = []
    for path in trial_dir_paths:
        config_file = path / config.experiment_files.config
        if config.last_instead_of_best:
            result_file = path / config.experiment_files.last_model
        else:
            result_file = path / config.experiment_files.best_model
        all_files_exist = all([
            config_file.is_file(),
            result_file.is_file()
        ])
        if not all_files_exist:
            log_warn(f"WARNING: one or more files are missing in {path}. Skipping this hyperparameter setting."
                     + f" <{config_file}>: {config_file.is_file()} and\n <{result_file}>: {result_file.is_file()})")
            continue

        yaml_parser = YAMLParser()
        yaml_content = yaml_parser.parse_yaml(config_file)
        # convert the sub dicts first, then the dict itself
        yaml_content = convert_type_inside_dict(yaml_content, src=dict, tgt=AttributeDict)
        yaml_content = AttributeDict(yaml_content)

        data = pd.json_normalize(yaml_content)
        with open(result_file, "r", encoding="utf8") as f:
            content = json.load(f)
        # NOTE(review): the json is indexed with [0], i.e. it is assumed to be a
        # non-empty list whose first entry maps metric names to values
        if metric_of_value_to_plot in content[0]:
            data.at[0, metric_of_value_to_plot] = content[0][metric_of_value_to_plot]
        else:
            log_warn(f"could not find value for {metric_of_value_to_plot} in json")
        dfs.append(data)

    if not dfs:
        raise ValueError("no dataframes found, check your config")
    return pd.concat(dfs, sort=False)


def create_matrix_plot(dataframe: pd.DataFrame, config: AttributeDict, cols: str, idx: str, ax=None,
                       cbar: bool = True, vmin: None | int = None, vmax: None | int = None):
    """Create one heatmap and put it into the grid of subplots.

    Uses ``pd.pivot_table()`` (mean of the metric) and ``sns.heatmap()``.

    Args:
        dataframe: trials to plot; the ``evaluation.plot.*`` columns of its
            first row provide metric name, number format, and metric mode.
        config: evaluation config (``plotstyle`` and ``plot.std``/``plot.aggfunc``).
        cols: column used for the x-axis.
        idx: column used for the y-axis.
        ax: matplotlib axes to draw into.
        cbar: whether to draw a colorbar.
        vmin: lower colorbar limit in unscaled metric units, or None.
        vmax: upper colorbar limit in unscaled metric units, or None.

    Returns:
        The axes returned by ``sns.heatmap``.
    """
    df_entry = dataframe.iloc[0]
    metric_name = df_entry["evaluation.plot.metric"]

    # CLEANING LAZY USER INPUT
    # cols are x-axis, idx are y-axis
    if cols not in dataframe.columns:
        log_warn("x-axis value not present in the dataframe; did you forget to add a 'optimizer.' as a prefix?\n"
                 + f" using '{'optimizer.' + cols}' as 'x-axis' instead.")
        cols = "optimizer." + cols
    if idx not in dataframe.columns:
        log_warn("y-axis value not present in the dataframe; did you forget to add a 'optimizer.' as a prefix?\n"
                 + f" using '{'optimizer.' + idx}' as 'y-axis' instead.")
        idx = "optimizer." + idx

    # create pivot table and format the score result
    pivot_table = pd.pivot_table(dataframe, columns=cols, index=idx, values=metric_name, aggfunc='mean')

    # the format is "<exponent>.<decimals>": values are multiplied by
    # 10**exponent and rounded to the given number of decimals; the colorbar
    # limits have to be scaled the same way
    format_string = dataframe["evaluation.plot.format"].iloc[0]
    value_exp_factor, decimal_points = format_string.split(".")
    value_exp_factor = int(value_exp_factor)
    decimal_points = int(decimal_points)
    scaling = 10 ** value_exp_factor
    if vmin is not None:
        vmin *= scaling
    if vmax is not None:
        vmax *= scaling
    pivot_table = (pivot_table * scaling).round(decimal_points)
    fmt = f".{decimal_points}f"

    # up to here limits was the min and max over all dataframes,
    # usually we want to use user values
    if "evaluation.plot.limits" in dataframe.columns:
        limits = dataframe["evaluation.plot.limits"].iloc[0]
        if limits:
            vmin = min(limits)
            vmax = max(limits)
            log_debug(f"setting cbar limits to {vmin}, {vmax} ")

    colormap_name = config.plotstyle.color_palette
    low_is_better = dataframe["evaluation.plot.test_metric_mode"].iloc[0] == "min"
    if low_is_better:
        colormap_name += "_r"  # this will invert / flip the colorbar
    colormap = sns.color_palette(colormap_name, as_cmap=True)
    metric_legend = pretty_name(metric_name)

    # a dedicated colorbar axes could be placed here, e.g. via fig.add_axes
    cbar_ax = None

    if not config.plot.std:
        annot = True
    else:
        # BUILD STD TABLE
        # cells that cannot be aggregated are filled with inf as a marker
        pivot_table_std = pd.pivot_table(dataframe,
                                         columns=cols, index=idx, values=metric_name,
                                         aggfunc=config.plot.aggfunc,
                                         fill_value=float("inf"),
                                         dropna=False)
        if float("inf") in pivot_table_std.values.flatten():
            log_warn("WARNING: Not enough data to calculate the std, skipping std in plot")
        pivot_table_std = (pivot_table_std * scaling).round(decimal_points)

        annot = pivot_table.copy().astype("string")
        for i in pivot_table.index:
            for j in pivot_table.columns:
                mean = pivot_table.loc[i, j]
                std = pivot_table_std.loc[i, j]
                std_string = f"\n±({round(std, decimal_points)})" if std != float("inf") else ""  # type: ignore
                annot.loc[i, j] = f"{round(mean, decimal_points)}{std_string}"  # type: ignore
        fmt = ""  # cannot format like before, as we do not only have a number

    return sns.heatmap(pivot_table, ax=ax, cbar_ax=cbar_ax,
                       annot=annot, fmt=fmt,
                       annot_kws={'fontsize': config.plotstyle.matrix_font.size},
                       cbar=cbar, vmin=vmin, vmax=vmax,
                       cmap=colormap, cbar_kws={'label': f"{metric_legend}"})


def get_all_num_rows_and_their_names(dataframe_list: list[pd.DataFrame], config):
    """Compute, for every dataframe, how many heatmap rows it needs and their names.

    Returns:
        A tuple ``(n_rows, row_names)``; ``n_rows[i]`` is the row count of the
        i-th dataframe and ``row_names[i]`` the list of ``"column=value"``
        names (``["default"]`` if nothing has to be split).
    """
    n_rows: list[int] = []
    row_names: list[list[str]] = []
    for i, df in enumerate(dataframe_list):
        # axes, metrics, and user-excluded keys never create additional rows
        ignored_cols = [config.plot.x_axis[i], config.plot.y_axis[0]]
        ignored_cols += list(df["evaluation.plot.metric"].unique())
        ignored_cols += config.get("ignore_keys", [])
        ignored_cols += config.get("aggregate_groups", [])
        current_n_rows, current_names = get_num_rows(df, ignored_cols, config)
        if not current_names:  # will be empty if we have only one row
            current_names.append("default")
        n_rows.append(current_n_rows)
        row_names.append(current_names)
    return n_rows, row_names


def get_num_rows(dataframe: pd.DataFrame, ignored_cols: list[str], config: AttributeDict
                 ) -> tuple[int, list[str]]:
    """Determine into how many rows the heatmaps of *dataframe* are split.

    Each matrix has 2 params (one for x and y each), one value, and we
    aggregate over seeds; if there are more than these 4 parameters with
    different values, we want to put them in separate rows instead of
    aggregating over them.

    ``config.split_groups`` controls which columns may be split on: a list
    whitelists exactly these columns, ``False`` disables splitting, anything
    else whitelists every column.

    Returns:
        The number of rows (at least 1) and the ``"column=value"`` names of
        the rows.
    """
    necessary_rows = 0

    # everything is whitelisted if this value stays 'all'
    whitelisted_cols: list[str] | Literal["all"] = "all"
    if isinstance(config.split_groups, list):
        whitelisted_cols = config.split_groups[:]
    elif config.split_groups is False:
        whitelisted_cols = []
    whitelist_is_explicit = whitelisted_cols != "all"

    columns_with_non_unique_values = []
    for col in dataframe.columns:
        is_eval_key = col.startswith("evaluation.")
        is_ignored = col in ignored_cols
        is_whitelisted = (not whitelist_is_explicit) or col in whitelisted_cols
        if is_ignored or is_eval_key or not is_whitelisted:
            # only warn about columns the user listed explicitly; with the
            # implicit 'all' whitelist this would fire for every ignored column
            if whitelist_is_explicit and is_whitelisted:
                log_warn(f"{col} is in the whitelist, but will be ignored. Probably {col} is in both 'split_groups' and 'aggregate_groups'.")
            log_debug(f"ignoring {col}")
            continue
        nunique = dataframe[col].nunique(dropna=False)
        if nunique > 1:
            log_debug(f"adding {col} since there are {nunique} unique values")
            for unique_hp in dataframe[col].unique():
                columns_with_non_unique_values.append(f"{col}={unique_hp}")
            necessary_rows += nunique

    # each unique parameter should be an individual plot
    rows_number = max(necessary_rows, 1)
    col_names = columns_with_non_unique_values
    log_debug(f"{rows_number=}")
    log_debug(f"{col_names=}")
    return rows_number, col_names


def find_global_vmin_vmax(dataframe_list, config):
    """Find colorbar limits shared by all subplots.

    Returns:
        ``(vmin, vmax)`` over the mean pivot tables of all dataframes, or
        ``(None, None)`` for a single dataframe so that matplotlib chooses.
    """
    vmin: int | float | None = None
    vmax: int | float | None = None
    if len(dataframe_list) > 1:
        # all subplots should have same colors -> we need to find the limits
        vmin = float('inf')
        vmax = float('-inf')
        key = config.plot.metric
        idx = config.plot.y_axis[0]
        for i, dataframe in enumerate(dataframe_list):
            cols = config.plot.x_axis[i]
            pivot_table = pd.pivot_table(dataframe, columns=cols, index=idx, values=key, aggfunc='mean')
            min_value_present_in_current_df = pivot_table.min().min()
            max_value_present_in_current_df = pivot_table.max().max()
            log_debug("colorbar_limits:\n"
                      + f" subfigure number {i+1}, checking for metric {key}: \n"
                      + f" min value is {min_value_present_in_current_df},\n"
                      + f" max value is {max_value_present_in_current_df}")
            vmin = min(vmin, min_value_present_in_current_df)
            vmax = max(vmax, max_value_present_in_current_df)
    return vmin, vmax


def create_figure(dataframe_list: list[pd.DataFrame], config: AttributeDict):
    """Create the full figure from a list of dataframes.

    Each dataframe is processed into a column of heatmaps.

    Returns:
        ``(fig, axs)`` where ``axs`` can be indexed as ``axs[row][column]``.
    """
    num_cols: int = len(dataframe_list)

    # calculate the number of rows for each dataframe
    n_rows, row_names = get_all_num_rows_and_their_names(dataframe_list, config)

    # Handling of the number of rows in the plot
    # we could either create a full rectangular grid, or allow each subplot to nest subplots
    # for nesting we would need to create subfigures instead of subplots
    if config.split_groups is False:
        n_rows_max = 1
        row_names = [["default"] for _ in range(num_cols)]
    else:
        n_rows_max = max(n_rows)
    log_debug(f"{n_rows=} and {num_cols=}")

    # TODO: figsize was hardcoded for a (1, 2) grid and left to default for a (1, 1) grid;
    # it was slightly adapted to allow multiple rows without being completely unreadable
    scale = config.plotstyle.scale
    if num_cols == 1 and n_rows_max > 1:
        figsize = (2**3 * scale, 2 * 3 * n_rows_max * scale)
    elif num_cols == 2:
        # TODO: after removing cbar from left subfigure, it is squished
        # there is an argument to share the legend, we should use that
        figsize = (12 * scale, 5.4 * n_rows_max * scale)
    elif num_cols > 2:
        figsize = (12 * (num_cols / 2) * scale, 5.4 * n_rows_max * scale)
    else:
        figsize = None

    # TODO: use seaborn FacetGrid
    fig, axs = plt.subplots(n_rows_max, num_cols, figsize=figsize)
    # plt.subplots drops dimensions of size one; restore them so that axs is
    # always indexable as axs[row][column]
    if n_rows_max == 1:
        axs = [axs]
    if num_cols == 1:
        axs = [[ax] for ax in axs]

    # None -> plt will choose vmin and vmax
    vmin, vmax = find_global_vmin_vmax(dataframe_list, config)

    for i in range(num_cols):
        num_nested_subfigures: int = n_rows[i]
        if not config.split_groups:
            create_one_grid_element(dataframe_list, config, axs, i,
                                    j=0,
                                    max_i=num_cols,
                                    max_j=0,
                                    vmin=vmin,
                                    vmax=vmax,
                                    n_rows=n_rows,
                                    row_names=row_names)
        else:
            for j in range(num_nested_subfigures):
                create_one_grid_element(dataframe_list, config, axs, i,
                                        j,
                                        max_i=num_cols,
                                        max_j=num_nested_subfigures,
                                        vmin=vmin,
                                        vmax=vmax,
                                        n_rows=n_rows,
                                        row_names=row_names)

    if config.plotstyle.tight_layout:
        fig.tight_layout()
    # TODO: a super title (fig.suptitle with config.experiment_name) might be
    # squished when used together with tight layout, so none is set for now
    return fig, axs


def create_one_grid_element(dataframe_list: list[pd.DataFrame], config: AttributeDict, axs,
                            i: int, j: int, max_i: int, max_j: int, vmin, vmax, n_rows, row_names):
    """Draw one 'axs' element (as it is called in plt) of the grid.

    Args:
        dataframe_list: one dataframe per grid column.
        config: evaluation config.
        axs: axes grid, indexed as ``axs[j][i]``.
        i: column index (from left to right).
        j: row index (from top to bottom).
        max_i: number of columns.
        max_j: number of rows of this column.
        vmin: lower colorbar limit, or None.
        vmax: upper colorbar limit, or None.
        n_rows: row counts per column (not read here, kept for the interface).
        row_names: ``row_names[i][j]`` is ``"default"`` or a ``"column=value"``
            string selecting the trials of this row.

    Returns:
        ``False`` if the row had to be skipped because its group could not be
        selected, otherwise ``None``.
    """
    name_for_additional_subplots: list[str] = row_names[i]
    num_subfigures = max_i  # from left to right
    num_nested_subfigures = max_j  # from top to bottom
    dataframe = dataframe_list[i]
    cols = config.plot.x_axis[i]
    idx = config.plot.y_axis[0]
    # only include colorbar once
    include_cbar: bool = i == num_subfigures - 1

    model_param = name_for_additional_subplots[j]
    if model_param == "default":
        current_dataframe = dataframe  # we do not need to do further grouping
    else:
        param_name, param_value = model_param.split("=", maxsplit=1)
        # the value was stringified for the row name; restore a comparable type.
        # bool has to be checked first: pandas counts it as numeric, but
        # float("True") raises a ValueError
        if pd.api.types.is_bool_dtype(dataframe[param_name]):
            param_value = param_value == "True"
        elif pd.api.types.is_numeric_dtype(dataframe[param_name]):
            param_value = float(param_value)
        try:
            current_dataframe = dataframe.groupby([param_name]).get_group((param_value,))
        except KeyError:
            log_warn(f"WARNING: was not able to groupby '{param_name}',"
                     + "maybe the data was created with different versions of fob; skipping this row")
            log_debug(f"{param_name=}{param_value=}{dataframe.columns=}{dataframe[param_name]=}")
            return False

    current_plot = create_matrix_plot(current_dataframe, config,
                                      cols, idx,
                                      ax=axs[j][i],
                                      cbar=include_cbar, vmin=vmin, vmax=vmax)

    # LABELS
    # Pretty name for label "learning_rate" => "Learning Rate"
    # remove x_label of all but last row, remove y_label for all but first column
    if i > 0:
        current_plot.set_ylabel('', labelpad=8)
    else:
        current_plot.set_ylabel(pretty_name(current_plot.get_ylabel()))
    if j < num_nested_subfigures - 1:
        current_plot.set_xlabel('', labelpad=8)
    else:
        current_plot.set_xlabel(pretty_name(current_plot.get_xlabel()))

    # reading optimizer and task name after grouping
    df_entry = current_dataframe.iloc[0]  # just get an arbitrary trial
    opti_name = df_entry['optimizer.name']
    task_name = df_entry['task.name']

    # TITLE
    # title (heading) of the heatmap: <optimizer> on <task> (+ additional info)
    title = f"{pretty_name(opti_name)} on {pretty_name(task_name)}"
    if max_i > 1 or max_j > 1:
        title += "" if model_param == "default" else f"\n{model_param}"
    current_plot.set_title(title)


def extract_dataframes(workload_paths: List[Path], config: AttributeDict, depth: int = 1
                       ) -> list[pd.DataFrame]:
    """Load one dataframe per workload directory.

    Every path is searched for trials with ``get_available_trials`` and the
    trials found are packed with ``dataframe_from_trials``.
    """
    return [
        dataframe_from_trials(get_available_trials(workload_path, config, depth), config)
        for workload_path in workload_paths
    ]


def get_output_file_path(dataframe_list: list[pd.DataFrame], config: AttributeDict, suffix: str = "") -> Path:
    """Build the output path (without file extension) for plots and csv files.

    The directory is ``config.output_dir`` or, if unset, the directory of this
    file. The file name is ``config.experiment_name`` or, if unset,
    ``<optimizer names>-<task names>`` taken from the first row of every
    dataframe. A non-empty *suffix* is appended as ``-<suffix>``.
    """
    task_names = [df.iloc[0]["task.name"] for df in dataframe_list]
    optim_names = [df.iloc[0]["optimizer.name"] for df in dataframe_list]
    task_name = "_".join(sorted(set(task_names)))
    optim_name = "_".join(sorted(set(optim_names)))

    here = Path(__file__).parent.resolve()

    output_dir = Path(config.output_dir) if config.output_dir else here
    experiment_name = Path(config.experiment_name) if config.experiment_name else f"{optim_name}-{task_name}"
    output_file_path = output_dir / experiment_name

    return Path(f"{output_file_path}-{suffix}" if suffix else output_file_path)


def set_plotstyle(config: AttributeDict):
    """Copy the text and font settings of ``config.plotstyle`` into ``plt.rcParams``."""
    plt.rcParams["text.usetex"] = config.plotstyle.text.usetex
    plt.rcParams["font.family"] = config.plotstyle.font.family
    plt.rcParams["font.size"] = config.plotstyle.font.size


def pretty_name(name: str, pretty_names: dict | str = {}) -> str:  # type: ignore pylint: disable=dangerous-default-value
    """Return a human readable version of *name*.

    Tries to use a mapping for the name, else will do some general replacement
    (underscores become spaces, words are title-cased).

    Args:
        name: the name to prettify, optionally with a dotted yaml prefix.
        pretty_names: either a dictionary mapping names to labels, or the
            filename of a yaml file with a 'names' key. With the default
            (empty dictionary) ``labels.yaml`` in the evaluation directory is
            read. The default dictionary is never mutated.
    """
    # the yaml file is parsed on every call that does not pass a dictionary
    label_file: Path = evaluation_path() / "labels.yaml"
    if isinstance(pretty_names, str):
        label_file = Path(pretty_names)
    if pretty_names == {} or isinstance(pretty_names, str):
        yaml_parser = YAMLParser()
        yaml_content = yaml_parser.parse_yaml(label_file)
        mapping: dict[str, str] = yaml_content["names"]
    else:
        mapping = pretty_names

    # applying pretty names: full name first, then the name without yaml prefix
    name_without_yaml_prefix = name.split(".")[-1]
    if name in mapping:
        return mapping[name]
    if name_without_yaml_prefix in mapping:
        return mapping[name_without_yaml_prefix]
    return name.replace('_', ' ').title()


def save_csv(dfs: list[pd.DataFrame], output_filename: Path):
    """Write every dataframe to ``<output_filename>-<index>.csv`` (without the index column)."""
    for i, df in enumerate(dfs):
        csv_output_filename = f"{output_filename.resolve()}-{i}.csv"
        log_info(f"saving raw data as {csv_output_filename}")
        df.to_csv(path_or_buf=csv_output_filename, index=False)


def save_plot(fig: Figure, output_file_path: Path, file_type: str, dpi: int):
    """Save *fig* as ``<output_file_path>.<file_type>`` with the given *dpi*."""
    plot_output_filename = f"{output_file_path.resolve()}.{file_type}"
    log_info(f"saving figure as <{plot_output_filename}>")
    fig.savefig(plot_output_filename, dpi=dpi)


def save_files(fig, dfs: list[pd.DataFrame], output_file_path: Path, config: AttributeDict):
    """Save figure and raw data in every format listed in ``config.output_types``.

    Supported types are "csv", "png", and "pdf"; other entries are ignored.
    The parent directory of *output_file_path* is created if necessary.
    """
    output_file_path.parent.mkdir(parents=True, exist_ok=True)
    for file_type in config.output_types:
        if file_type == "csv":
            save_csv(dfs, output_file_path)
        elif file_type in ("png", "pdf"):
            save_plot(fig, output_file_path, file_type, config.plotstyle.dpi)


def clean_config(config: AttributeDict) -> AttributeDict:
    """Normalize the user config so that the user can be lazy.

    The 'evaluation' namespace becomes the returned config (the full config
    stays reachable as ``all_values``), and ``output_types``, ``data_dirs``,
    ``plot.x_axis``, and ``plot.y_axis`` are turned into lists. Axis lists
    shorter than ``data_dirs`` are padded with their first entry.

    Raises:
        ValueError: if ``data_dirs`` is present but empty or of a wrong type.
    """
    if "evaluation" in config.keys():
        evaluation_config: AttributeDict = config.evaluation
        evaluation_config["all_values"] = config
        config = evaluation_config
    else:
        log_warn("there is no 'evaluation' in the yaml provided!")
    if "data_dirs" in config.keys():
        value_is_none = not config.data_dirs
        value_has_wrong_type = not isinstance(config.data_dirs, (PathLike, str, list))
        if value_is_none or value_has_wrong_type:
            raise ValueError(f"Error: 'evaluation.data_dirs' was not provided correctly! check for typos in the yaml provided! value given: {config.data_dirs}")

    # allow the user to write a single string instead of a list of strings
    if not isinstance(config.output_types, list):
        config["output_types"] = [config.output_types]
        log_info("fixing value for key <config.output_types> to be a list[str]")

    if not isinstance(config.data_dirs, list):
        config["data_dirs"] = [Path(config.data_dirs)]
        log_info("fixing value for key <config.data_dirs> to be a list[Path]")

    # x_axis
    if not isinstance(config.plot.x_axis, list):
        config["plot"]["x_axis"] = [config.plot.x_axis]
        log_info("fixing value for key <config.plot.x_axis> to be a list[str]")
    if len(config.plot.x_axis) < len(config.data_dirs):
        # use same x axis for all if only one given
        missing_elements = len(config.data_dirs) - len(config.plot.x_axis)
        config["plot"]["x_axis"] += repeat(config.plot.x_axis[0], missing_elements)

    # y_axis
    if not isinstance(config.plot.y_axis, list):
        config["plot"]["y_axis"] = [config.plot.y_axis]
        log_info("fixing value for key <config.plot.y_axis> to be a list[str]")
    if len(config.plot.y_axis) < len(config.data_dirs):
        # use same y axis for all if only one given
        missing_elements = len(config.data_dirs) - len(config.plot.y_axis)
        config["plot"]["y_axis"] += repeat(config.plot.y_axis[0], missing_elements)

    return config


def main(config: AttributeDict):
    """Create the heatmap figure for *config* and save all requested output files."""
    config = clean_config(config)  # sets config to config.evaluation, cleans some data
    workloads: List[Path] = [Path(name) for name in config.data_dirs]
    log_debug(f"{workloads=}")

    set_plotstyle(config)

    dfs = extract_dataframes(workloads, depth=config.depth, config=config)
    fig, _ = create_figure(dfs, config)

    output_file_path = get_output_file_path(dfs, config)

    save_files(fig, dfs, output_file_path, config)