mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
564 lines
24 KiB
Python
564 lines
24 KiB
Python
import json
|
|
from pathlib import Path
|
|
from os import PathLike
|
|
from typing import List, Literal
|
|
from itertools import repeat
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.figure import Figure
|
|
import seaborn as sns
|
|
import pandas as pd
|
|
from pytorch_fob.engine.parser import YAMLParser
|
|
from pytorch_fob.engine.utils import AttributeDict, convert_type_inside_dict, log_warn, log_info, log_debug
|
|
from pytorch_fob.evaluation import evaluation_path
|
|
|
|
|
|
def get_available_trials(dirname: Path, config: AttributeDict, depth: int = 1):
|
|
"""finds the path for all trials in the *dirname* directory"""
|
|
# RECURSIVELY FIND ALL DIRS IN DIRNAME (up to depth)
|
|
assert isinstance(dirname, Path)
|
|
subdirs: list[Path] = [dirname]
|
|
all_results_must_be_same_depth = True
|
|
for _ in range(depth):
|
|
if all_results_must_be_same_depth:
|
|
new_subdirs: list[Path] = []
|
|
for subdir in subdirs:
|
|
new_subdirs += [x for x in subdir.iterdir() if x.is_dir()]
|
|
subdirs = new_subdirs
|
|
else:
|
|
for subdir in subdirs:
|
|
subdirs += [x for x in subdir.iterdir() if x.is_dir()]
|
|
format_str = "\n " # f-string expression part cannot include a backslash
|
|
log_debug(f"found the following directories:{format_str}{format_str.join(str(i) for i in subdirs)}.")
|
|
|
|
def is_trial(path: Path):
|
|
# here we could do additional checks to filter the subdirectories
|
|
# currently we only check if there is a config file
|
|
for x in path.iterdir():
|
|
found_a_config_file = x.name == config.experiment_files.config
|
|
if found_a_config_file:
|
|
return True
|
|
return False
|
|
|
|
subdirs = list(filter(is_trial, subdirs[::-1]))
|
|
log_debug(f"We assume the following to be trials:{format_str}{format_str.join(str(i) for i in subdirs)}.")
|
|
return subdirs
|
|
|
|
|
|
def dataframe_from_trials(trial_dir_paths: List[Path], config: AttributeDict) -> pd.DataFrame:
|
|
"""takes result from get_available_trials and packs them in a dataframe,
|
|
does not filter duplicate hyperparameter settings."""
|
|
dfs: List[pd.DataFrame] = []
|
|
|
|
for path in trial_dir_paths:
|
|
|
|
config_file = path / config.experiment_files.config
|
|
if config.last_instead_of_best:
|
|
result_file = path / config.experiment_files.last_model
|
|
else:
|
|
result_file = path / config.experiment_files.best_model
|
|
all_files_exist = all([
|
|
config_file.is_file(),
|
|
result_file.is_file()
|
|
])
|
|
if not all_files_exist:
|
|
log_warn(f"WARNING: one or more files are missing in {path}. Skipping this hyperparameter setting." +
|
|
f" <{config_file}>: {config_file.is_file()} and\n <{result_file}>: {result_file.is_file()})")
|
|
continue
|
|
|
|
yaml_parser = YAMLParser()
|
|
yaml_content = yaml_parser.parse_yaml(config_file)
|
|
# convert the sub dicts first, then the dict itself
|
|
yaml_content = convert_type_inside_dict(yaml_content, src=dict, tgt=AttributeDict)
|
|
yaml_content = AttributeDict(yaml_content)
|
|
|
|
# use user given value
|
|
metric_of_value_to_plot = config.plot.metric
|
|
|
|
# compute it if user has not given a value
|
|
if not metric_of_value_to_plot:
|
|
raise ValueError("evaluation.plot.metric is not set")
|
|
|
|
data = pd.json_normalize(yaml_content)
|
|
|
|
with open(result_file, "r", encoding="utf8") as f:
|
|
content = json.load(f)
|
|
if metric_of_value_to_plot in content[0]:
|
|
data.at[0, metric_of_value_to_plot] = content[0][metric_of_value_to_plot]
|
|
else:
|
|
log_warn(f"could not find value for {metric_of_value_to_plot} in json")
|
|
|
|
dfs.append(data)
|
|
|
|
if len(dfs) == 0:
|
|
raise ValueError("no dataframes found, check your config")
|
|
df = pd.concat(dfs, sort=False)
|
|
|
|
return df
|
|
|
|
|
|
def create_matrix_plot(dataframe: pd.DataFrame, config: AttributeDict, cols: str, idx: str, ax=None,
|
|
cbar: bool = True, vmin: None | int = None, vmax: None | int = None):
|
|
"""
|
|
Creates one heatmap and puts it into the grid of subplots.
|
|
Uses pd.pivot_table() and sns.heatmap().
|
|
"""
|
|
df_entry = dataframe.iloc[0]
|
|
metric_name = df_entry["evaluation.plot.metric"]
|
|
|
|
# CLEANING LAZY USER INPUT
|
|
# cols are x-axis, idx are y-axis
|
|
if cols not in dataframe.columns:
|
|
log_warn("x-axis value not present in the dataframe; did you forget to add a 'optimizer.' as a prefix?\n" +
|
|
f" using '{'optimizer.' + cols}' as 'x-axis' instead.")
|
|
cols = "optimizer." + cols
|
|
if idx not in dataframe.columns:
|
|
log_warn("y-axis value not present in the dataframe; did you forget to add a 'optimizer.' as a prefix?\n" +
|
|
f" using '{'optimizer.' + idx}' as 'y-axis' instead.")
|
|
idx = "optimizer." + idx
|
|
# create pivot table and format the score result
|
|
pivot_table = pd.pivot_table(dataframe,
|
|
columns=cols, index=idx, values=metric_name,
|
|
aggfunc='mean')
|
|
|
|
fmt = None
|
|
format_string = dataframe["evaluation.plot.format"].iloc[0]
|
|
|
|
# scaline the values given by the user to fit his format needs (-> and adapting the limits)
|
|
value_exp_factor, decimal_points = format_string.split(".")
|
|
value_exp_factor = int(value_exp_factor)
|
|
decimal_points = int(decimal_points)
|
|
if vmin:
|
|
vmin *= (10 ** value_exp_factor)
|
|
if vmax:
|
|
vmax *= (10 ** value_exp_factor)
|
|
pivot_table = (pivot_table * (10 ** value_exp_factor)).round(decimal_points)
|
|
fmt=f".{decimal_points}f"
|
|
|
|
# up to here limits was the min and max over all dataframes,
|
|
# usually we want to use user values
|
|
if "evaluation.plot.limits" in dataframe.columns:
|
|
limits = dataframe["evaluation.plot.limits"].iloc[0]
|
|
if limits:
|
|
vmin = min(limits)
|
|
vmax = max(limits)
|
|
log_debug(f"setting cbar limits to {vmin}, {vmax} ")
|
|
|
|
colormap_name = config.plotstyle.color_palette
|
|
low_is_better = dataframe["evaluation.plot.test_metric_mode"].iloc[0] == "min"
|
|
if low_is_better:
|
|
colormap_name += "_r" # this will "inver" / "flip" the colorbar
|
|
colormap = sns.color_palette(colormap_name, as_cmap=True)
|
|
metric_legend = pretty_name(metric_name)
|
|
|
|
# FINETUNE POSITION
|
|
# left bottom width height
|
|
# cbar_ax = fig.add_axes([0.92, 0.235, 0.02, 0.6])
|
|
cbar_ax = None
|
|
|
|
if not config.plot.std:
|
|
return sns.heatmap(pivot_table, ax=ax, cbar_ax=cbar_ax,
|
|
annot=True, fmt=fmt,
|
|
annot_kws={'fontsize': config.plotstyle.matrix_font.size},
|
|
cbar=cbar, vmin=vmin, vmax=vmax, cmap=colormap, cbar_kws={'label': f"{metric_legend}"})
|
|
else:
|
|
# BUILD STD TABLE
|
|
pivot_table_std = pd.pivot_table(dataframe,
|
|
columns=cols, index=idx, values=metric_name,
|
|
aggfunc=config.plot.aggfunc, fill_value=float("inf"), dropna=False
|
|
)
|
|
if float("inf") in pivot_table_std.values.flatten():
|
|
log_warn("WARNING: Not enough data to calculate the std, skipping std in plot")
|
|
|
|
pivot_table_std = (pivot_table_std * (10 ** value_exp_factor)).round(decimal_points)
|
|
|
|
annot_matrix = pivot_table.copy().astype("string")
|
|
for i in pivot_table.index:
|
|
for j in pivot_table.columns:
|
|
mean = pivot_table.loc[i, j]
|
|
std = pivot_table_std.loc[i, j]
|
|
std_string = f"\n±({round(std, decimal_points)})" if std != float("inf") else "" # type: ignore
|
|
annot_matrix.loc[i, j] = f"{round(mean, decimal_points)}{std_string}" # type: ignore
|
|
|
|
fmt = "" # cannot format like before, as we do not only have a number
|
|
|
|
return sns.heatmap(pivot_table, ax=ax, cbar_ax=cbar_ax,
|
|
annot=annot_matrix, fmt=fmt,
|
|
annot_kws={'fontsize': config.plotstyle.matrix_font.size},
|
|
cbar=cbar, vmin=vmin, vmax=vmax, cmap=colormap, cbar_kws={'label': f"{metric_legend}"})
|
|
|
|
|
|
def get_all_num_rows_and_their_names(dataframe_list: list[pd.DataFrame], config):
|
|
n_rows: list[int] = []
|
|
row_names: list[list[str]] = []
|
|
for i, df in enumerate(dataframe_list):
|
|
x_axis = config.plot.x_axis[i]
|
|
y_axis = config.plot.y_axis[0]
|
|
metrics = df["evaluation.plot.metric"].unique()
|
|
ignored_cols = [x_axis, y_axis]
|
|
ignored_cols += list(metrics)
|
|
ignored_cols += config.get("ignore_keys", [])
|
|
ignored_cols += config.get("aggregate_groups", [])
|
|
current_n_rows, current_names = get_num_rows(df, ignored_cols, config)
|
|
n_rows.append(current_n_rows)
|
|
if not current_names: # will be empty if we have only one row
|
|
current_names.append("default")
|
|
row_names.append(current_names)
|
|
|
|
return n_rows, row_names
|
|
|
|
def get_num_rows(dataframe: pd.DataFrame, ignored_cols: list[str], config: AttributeDict
|
|
) -> tuple[int, list[str]]:
|
|
"""each matrix has 2 params (on for x and y each), one value, and we aggregate over seeds;
|
|
if there are more than than these 4 parameter with different values,
|
|
we want to put that in seperate rows instead of aggregating over them.
|
|
returning: the number of rows (atleast 1) and the names of the cols"""
|
|
necesarry_rows = 0
|
|
|
|
# the user might specify a value for the groups that we should split on in <split_groups>
|
|
whitelisted_cols: list[str] | Literal["all"] = "all" # everything is whitelisted if this value stays 'all'
|
|
if isinstance(config.split_groups, list):
|
|
whitelisted_cols = config.split_groups[:]
|
|
elif config.split_groups is False:
|
|
whitelisted_cols = []
|
|
|
|
columns_with_non_unique_values = []
|
|
for col in dataframe.columns:
|
|
is_eval_key = col.startswith("evaluation.")
|
|
is_ignored = col in ignored_cols
|
|
is_whitelisted = whitelisted_cols == "all" or col in whitelisted_cols
|
|
if any([is_ignored, is_eval_key, not is_whitelisted]):
|
|
if is_whitelisted:
|
|
log_warn(f"{col} is in the whitelist, but will be ignored. Probably {col} is in both 'split_groups' and 'aggregate_groups'.")
|
|
log_debug(f"ignoring {col}")
|
|
continue
|
|
nunique = dataframe[col].nunique(dropna=False)
|
|
if nunique > 1:
|
|
log_debug(f"adding {col} since there are {nunique} unique values")
|
|
for unique_hp in dataframe[col].unique():
|
|
columns_with_non_unique_values.append(f"{col}={unique_hp}")
|
|
necesarry_rows += (nunique) # each unique parameter should be an individal plot
|
|
|
|
rows_number = max(necesarry_rows, 1)
|
|
col_names = columns_with_non_unique_values
|
|
log_debug(f"{rows_number=}")
|
|
log_debug(f"{col_names=}")
|
|
|
|
return rows_number, col_names
|
|
|
|
|
|
def find_global_vmin_vmax(dataframe_list, config):
|
|
vmin: int | float | None = None
|
|
vmax: int | float | None = None
|
|
num_cols = len(dataframe_list)
|
|
|
|
if num_cols > 1:
|
|
# all subplots should have same colors -> we need to find the limits
|
|
vmin = float('inf')
|
|
vmax = float('-inf')
|
|
|
|
for i in range(num_cols):
|
|
dataframe = dataframe_list[i]
|
|
cols = config.plot.x_axis[i]
|
|
idx = config.plot.y_axis[0]
|
|
key = config.plot.metric
|
|
|
|
pivot_table = pd.pivot_table(dataframe,
|
|
columns=cols, index=idx,
|
|
values=key,
|
|
aggfunc='mean')
|
|
|
|
min_value_present_in_current_df = pivot_table.min().min()
|
|
max_value_present_in_current_df = pivot_table.max().max()
|
|
|
|
log_debug("colorbar_limits:\n" +
|
|
f" subfigure number {i+1}, checking for metric {key}: \n" +
|
|
f" min value is {min_value_present_in_current_df},\n" +
|
|
f" max value is {max_value_present_in_current_df}")
|
|
vmin = min(vmin, min_value_present_in_current_df)
|
|
vmax = max(vmax, max_value_present_in_current_df)
|
|
|
|
return vmin, vmax
|
|
|
|
|
|
def create_figure(dataframe_list: list[pd.DataFrame], config: AttributeDict):
|
|
"""
|
|
Takes a list of dataframes. Each dataframe is processed into a column of heatmaps.
|
|
"""
|
|
num_cols: int = len(dataframe_list)
|
|
|
|
# calculate the number of rows for each dataframe
|
|
n_rows, row_names = get_all_num_rows_and_their_names(dataframe_list, config)
|
|
|
|
# Handling of the number of rows in the plot
|
|
# we could either create a full rectangular grid, or allow each subplot to nest subplots
|
|
# for nesting we would need to create subfigures instead of subplots i think
|
|
if config.split_groups is False:
|
|
n_rows_max = 1
|
|
row_names = [["default"] for _ in range(num_cols)]
|
|
else:
|
|
n_rows_max = max(n_rows)
|
|
|
|
log_debug(f"{n_rows=} and {num_cols=}")
|
|
|
|
# TODO, figsize was just hardcoded for (1, 2) grid and left to default for (1, 1) grid
|
|
# probably not worth the hazzle to create something dynamic (atleast not now)
|
|
# EDIT: it was slightly adapted to allow num rows without being completely unreadable
|
|
# margin = (num_subfigures - 1) * 0.3
|
|
# figsize=(5*n_cols + margin, 2.5)
|
|
scale = config.plotstyle.scale
|
|
if num_cols == 1 and n_rows_max > 1:
|
|
figsize = (2**3 * scale, 2 * 3 * n_rows_max * scale)
|
|
elif num_cols == 2:
|
|
# TODO: after removing cbar from left subifgure, it is squished
|
|
# there is an argument to share the legend, we should use that
|
|
figsize = (12 * scale, 5.4 * n_rows_max * scale)
|
|
elif num_cols > 2:
|
|
figsize = (12 * (num_cols / 2) * scale, 5.4 * n_rows_max * scale)
|
|
else:
|
|
figsize = None
|
|
|
|
# TODO: use seaborn FacetGrid
|
|
fig, axs = plt.subplots(n_rows_max, num_cols, figsize=figsize)
|
|
if n_rows_max == 1:
|
|
axs = [axs]
|
|
if num_cols == 1:
|
|
axs = [[ax] for ax in axs] # adapt for special case so we have unified types
|
|
|
|
# Adjust left and right margins as needed
|
|
# fig.subplots_adjust(left=0.1, right=0.9, top=0.97, hspace=0.38, bottom=0.05,wspace=0.3)
|
|
|
|
# None -> plt will chose vmin and vmax
|
|
vmin, vmax = find_global_vmin_vmax(dataframe_list, config)
|
|
|
|
for i in range(num_cols):
|
|
num_nested_subfigures: int = n_rows[i]
|
|
|
|
if not config.split_groups:
|
|
create_one_grid_element(dataframe_list, config, axs, i,
|
|
j=0,
|
|
max_i=num_cols,
|
|
max_j=0,
|
|
vmin=vmin,
|
|
vmax=vmax,
|
|
n_rows=n_rows,
|
|
row_names=row_names)
|
|
else:
|
|
for j in range(num_nested_subfigures):
|
|
create_one_grid_element(dataframe_list, config, axs, i,
|
|
j,
|
|
max_i=num_cols,
|
|
max_j=num_nested_subfigures,
|
|
vmin=vmin,
|
|
vmax=vmax,
|
|
n_rows=n_rows,
|
|
row_names=row_names)
|
|
|
|
if config.plotstyle.tight_layout:
|
|
fig.tight_layout()
|
|
# SUPTITLE (the super title on top of the whole figure in the middle)
|
|
# # TODO super title might be squished when used together with tight layout (removing for now)
|
|
# if n_rows_max > 1 or num_cols > 1:
|
|
# # set experiment name as title when multiple matrices in image
|
|
# if config.experiment_name:
|
|
# fig.suptitle(config.experiment_name)
|
|
return fig, axs
|
|
|
|
|
|
def create_one_grid_element(dataframe_list: list[pd.DataFrame], config: AttributeDict, axs,
|
|
i: int, j: int, max_i: int, max_j: int, vmin, vmax, n_rows, row_names):
|
|
"""does one 'axs' element as it is called in plt"""
|
|
num_nested_subfigures: int = n_rows[i]
|
|
name_for_additional_subplots: list[str] = row_names[i]
|
|
num_subfigures = max_i # from left to right
|
|
num_nested_subfigures = max_j # from top to bottom
|
|
dataframe = dataframe_list[i]
|
|
|
|
cols = config.plot.x_axis[i]
|
|
idx = config.plot.y_axis[0]
|
|
# only include colorbar once
|
|
include_cbar: bool = i == num_subfigures - 1
|
|
|
|
model_param = name_for_additional_subplots[j]
|
|
if model_param == "default":
|
|
current_dataframe = dataframe # we do not need to do further grouping
|
|
else:
|
|
param_name, param_value = model_param.split("=", maxsplit=1)
|
|
if pd.api.types.is_numeric_dtype(dataframe[param_name]):
|
|
param_value = float(param_value)
|
|
try:
|
|
current_dataframe = dataframe.groupby([param_name]).get_group((param_value,))
|
|
except KeyError:
|
|
log_warn(f"WARNING: was not able to groupby '{param_name}'," +
|
|
"maybe the data was created with different versions of fob; skipping this row")
|
|
log_debug(f"{param_name=}{param_value=}{dataframe.columns=}{dataframe[param_name]=}")
|
|
return False
|
|
current_plot = create_matrix_plot(current_dataframe, config,
|
|
cols, idx,
|
|
ax=axs[j][i],
|
|
cbar=include_cbar, vmin=vmin, vmax=vmax)
|
|
|
|
# LABELS
|
|
# Pretty name for label "learning_rate" => "Learning Rate"
|
|
# remove x_label of all but last row, remove y_label for all but first column
|
|
if i > 0:
|
|
current_plot.set_ylabel('', labelpad=8)
|
|
else:
|
|
current_plot.set_ylabel(pretty_name(current_plot.get_ylabel()))
|
|
if j < num_nested_subfigures - 1:
|
|
current_plot.set_xlabel('', labelpad=8)
|
|
else:
|
|
current_plot.set_xlabel(pretty_name(current_plot.get_xlabel()))
|
|
|
|
# reading optimizer and task name after grouping
|
|
df_entry = current_dataframe.iloc[0] # just get an arbitrary trial
|
|
opti_name = df_entry['optimizer.name']
|
|
task_name = df_entry['task.name']
|
|
|
|
# TITLE
|
|
# title (heading) of the heatmap: <optimname> on <taskname> (+ additional info)
|
|
title = f"{pretty_name(opti_name)} on {pretty_name(task_name)}"
|
|
if max_i > 1 or max_j > 1:
|
|
title += "" if model_param == "default" else f"\n{model_param}"
|
|
current_plot.set_title(title)
|
|
|
|
|
|
def extract_dataframes(workload_paths: List[Path], config: AttributeDict, depth: int = 1
|
|
) -> list[pd.DataFrame]:
|
|
df_list: list[pd.DataFrame] = []
|
|
num_dataframes: int = len(workload_paths)
|
|
|
|
for i in range(num_dataframes):
|
|
available_trials = get_available_trials(workload_paths[i], config, depth)
|
|
dataframe = dataframe_from_trials(available_trials, config)
|
|
df_list.append(dataframe)
|
|
|
|
return df_list
|
|
|
|
|
|
def get_output_file_path(dataframe_list: list[pd.DataFrame], config: AttributeDict, suffix: str = "") -> Path:
|
|
task_names = [df.iloc[0]["task.name"] for df in dataframe_list]
|
|
optim_names = [df.iloc[0]["optimizer.name"] for df in dataframe_list]
|
|
task_name = "_".join(sorted(set(task_names)))
|
|
optim_name = "_".join(sorted(set(optim_names)))
|
|
|
|
here = Path(__file__).parent.resolve()
|
|
|
|
output_dir = Path(config.output_dir) if config.output_dir else here
|
|
experiment_name = Path(config.experiment_name) if config.experiment_name else f"{optim_name}-{task_name}"
|
|
output_file_path = output_dir / experiment_name
|
|
|
|
return Path(f"{output_file_path}-{suffix}" if suffix else output_file_path)
|
|
|
|
|
|
def set_plotstyle(config: AttributeDict):
|
|
plt.rcParams["text.usetex"] = config.plotstyle.text.usetex
|
|
plt.rcParams["font.family"] = config.plotstyle.font.family
|
|
plt.rcParams["font.size"] = config.plotstyle.font.size
|
|
|
|
def pretty_name(name: str, pretty_names: dict | str = {}) -> str: # type: ignore pylint: disable=dangerous-default-value
|
|
"""
|
|
Tries to use a mapping for the name, else will do some general replacement.
|
|
mapping can be a directory or a filename of a yaml file with 'names' key
|
|
"""
|
|
|
|
# reading from yaml and caching the dictionary
|
|
label_file: Path = evaluation_path() / "labels.yaml"
|
|
if isinstance(pretty_names, str):
|
|
label_file = Path(pretty_names)
|
|
|
|
if pretty_names == {} or isinstance(pretty_names, str):
|
|
yaml_parser = YAMLParser()
|
|
yaml_content = yaml_parser.parse_yaml(label_file)
|
|
pretty_names: dict[str, str] = yaml_content["names"]
|
|
|
|
# applying pretty names
|
|
name_without_yaml_prefix = name.split(".")[-1]
|
|
if name in pretty_names.keys():
|
|
name = pretty_names[name]
|
|
elif name_without_yaml_prefix in pretty_names.keys():
|
|
name = pretty_names[name_without_yaml_prefix]
|
|
else:
|
|
name = name.replace('_', ' ').title()
|
|
return name
|
|
|
|
|
|
def save_csv(dfs: list[pd.DataFrame], output_filename: Path):
|
|
for i, df in enumerate(dfs):
|
|
csv_output_filename = f"{output_filename.resolve()}-{i}.csv"
|
|
log_info(f"saving raw data as {csv_output_filename}")
|
|
df.to_csv(path_or_buf=csv_output_filename, index=False)
|
|
|
|
|
|
def save_plot(fig: Figure, output_file_path: Path, file_type: str, dpi: int):
|
|
plot_output_filename = f"{output_file_path.resolve()}.{file_type}"
|
|
log_info(f"saving figure as <{plot_output_filename}>")
|
|
fig.savefig(plot_output_filename, dpi=dpi)
|
|
|
|
|
|
def save_files(fig, dfs: list[pd.DataFrame], output_file_path: Path, config: AttributeDict):
|
|
output_file_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
for file_type in config.output_types:
|
|
if file_type == "csv":
|
|
save_csv(dfs, output_file_path)
|
|
elif file_type == "png" or file_type == "pdf":
|
|
save_plot(fig, output_file_path, file_type, config.plotstyle.dpi)
|
|
|
|
|
|
def clean_config(config: AttributeDict) -> AttributeDict:
|
|
"""some processing that allows the user to be lazy, shortcut for the namespace, hidden values are found and config.all_values"""
|
|
if "evaluation" in config.keys():
|
|
evaluation_config: AttributeDict = config.evaluation
|
|
evaluation_config["all_values"] = config
|
|
config = evaluation_config
|
|
else:
|
|
log_warn("there is no 'evaluation' in the yaml provided!")
|
|
if "data_dirs" in config.keys():
|
|
value_is_none = not config.data_dirs
|
|
value_has_wrong_type = not isinstance(config.data_dirs, (PathLike, str, list))
|
|
if value_is_none or value_has_wrong_type:
|
|
raise ValueError(f"Error: 'evaluation.data_dirs' was not provided correctly! check for typos in the yaml provided! value given: {config.data_dirs}")
|
|
|
|
# allow the user to write a single string instead of a list of strings
|
|
if not isinstance(config.output_types, list):
|
|
config["output_types"] = [config.output_types]
|
|
log_info("fixing value for key <config.output_types> to be a list[str]")
|
|
|
|
if not isinstance(config.data_dirs, list):
|
|
config["data_dirs"] = [Path(config.data_dirs)]
|
|
log_info("fixing value for key <config.data_dirs> to be a list[Path]")
|
|
|
|
# x_axis
|
|
if not isinstance(config.plot.x_axis, list):
|
|
config["plot"]["x_axis"] = [config.plot.x_axis]
|
|
log_info("fixing value for key <config.plot.x_axis> to be a list[str]")
|
|
if len(config.plot.x_axis) < len(config.data_dirs):
|
|
# use same x axis for all if only one given
|
|
missing_elements = len(config.data_dirs) - len(config.plot.x_axis)
|
|
config["plot"]["x_axis"] += repeat(config.plot.x_axis[0], missing_elements)
|
|
|
|
# y_axis
|
|
if not isinstance(config.plot.y_axis, list):
|
|
config["plot"]["y_axis"] = [config.plot.y_axis]
|
|
log_info("fixing value for key <config.plot.y_axis> to be a list[str]")
|
|
if len(config.plot.y_axis) < len(config.data_dirs):
|
|
# use same x axis for all if only one given
|
|
missing_elements = len(config.data_dirs) - len(config.plot.y_axis)
|
|
config["plot"]["y_axis"] += repeat(config.plot.y_axis[0], missing_elements)
|
|
|
|
return config
|
|
|
|
|
|
def main(config: AttributeDict):
|
|
config = clean_config(config) # sets config to config.evaluation, cleans some data
|
|
workloads: List[Path] = [Path(name) for name in config.data_dirs]
|
|
log_debug(f"{workloads}=")
|
|
|
|
set_plotstyle(config)
|
|
|
|
dfs = extract_dataframes(workloads, depth=config.depth, config=config)
|
|
fig, _ = create_figure(dfs, config)
|
|
|
|
output_file_path = get_output_file_path(dfs, config)
|
|
|
|
save_files(fig, dfs, output_file_path, config)
|