|
|
import pandas as pd |
|
|
import os |
|
|
|
|
|
|
|
|
def create_plots(path_file, path_save):
    """Render the three bar charts for a single results CSV.

    Args:
        path_file: Path of the CSV file to load with ``pd.read_csv``.
        path_save: Directory the PNG figures are written into.

    The loaded frame is plotted once per accuracy column
    ('Accuracy replies', 'Accuracy sources') and once more as a combined
    chart via ``generate_all_plot``.
    """
    results = pd.read_csv(path_file)

    # One chart per accuracy column, each saved under its own file name.
    per_metric_figures = (
        ('Accuracy replies', 'figure_replies.png'),
        ('Accuracy sources', 'figure_sources.png'),
    )
    for column, figure_name in per_metric_figures:
        generate_plot(results, column, path_save, figure_name)

    # Combined chart of the frame, indexed by 'Name'.
    generate_all_plot(results, path_save, 'figure_total.png')
|
|
|
|
|
def create_plots_list_values(list_path_file, path_save):
    """Plot mean/std accuracy aggregated over several results CSVs.

    Args:
        list_path_file: Iterable of CSV paths; each is loaded with
            ``pd.read_csv``.
        path_save: Directory the PNG figures are written into.

    Writes 'figure_replies_mean.png', 'figure_sources_mean.png' and
    'figure_total_mean.png', and prints the merged summary table.
    """
    frames = [pd.read_csv(csv_path) for csv_path in list_path_file]

    output_replies, output_sources = get_mean_std(frames)

    # Individual charts, with the 'std' column drawn as error bars.
    generate_plot(output_replies, 'mean', path_save, 'figure_replies_mean.png', std=True)
    generate_plot(output_sources, 'mean', path_save, 'figure_sources_mean.png', std=True)

    # Side-by-side summary: align both tables on 'Name' and join column-wise.
    summary = pd.concat(
        [output_replies.set_index('Name'), output_sources.set_index('Name')],
        axis=1,
    )
    summary.columns = ['mean_replies', 'std_r', 'mean_sources', 'std_s']
    summary = summary.sort_values(by='mean_replies', ascending=False)
    print(summary)

    # Error bars are passed as a 2-row array: one row per plotted column.
    error_bars = summary[['std_r', 'std_s']].T.values
    axes = summary.plot.barh(y=['mean_replies', 'mean_sources'], xerr=error_bars)

    target = os.path.join(path_save, 'figure_total_mean.png')
    axes.get_figure().savefig(target, bbox_inches='tight')
|
|
|
|
|
|
|
|
def get_mean_std(df_list):
    """Aggregate a list of result frames into per-name mean/std tables.

    Args:
        df_list: DataFrames to stack row-wise with ``pd.concat``.

    Returns:
        A 2-tuple ``(output_replies, output_sources)``: the result of
        ``get_mean_std_from_merge`` for 'Accuracy replies' and for
        'Accuracy sources', in that order.
    """
    stacked = pd.concat(df_list)
    return tuple(
        get_mean_std_from_merge(stacked, column)
        for column in ('Accuracy replies', 'Accuracy sources')
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_mean_std_from_merge(df_merge, accuracy_name):
    """Compute the mean and standard deviation of one column per 'Name'.

    Args:
        df_merge: DataFrame containing a 'Name' column and the column
            named by ``accuracy_name``.
        accuracy_name: Name of the column to aggregate.

    Returns:
        DataFrame with the flat columns 'Name', 'mean' and 'std', one row
        per distinct 'Name'.
    """
    by_name = df_merge.groupby(['Name'], as_index=False)
    stats = by_name.agg({accuracy_name: ['mean', 'std']})

    # The aggregation yields two-level column labels; keep only the inner
    # level, where the grouping key shows up as an empty label.
    inner_labels = stats.columns.get_level_values(1)
    stats.columns = ['Name' if label == '' else label for label in inner_labels]
    return stats
|
|
|
|
|
def generate_plot(df, y, path_save, name_figure, std = False):
    """Save a horizontal bar chart of column ``y`` against 'Name'.

    Args:
        df: DataFrame with a 'Name' column and the column named by ``y``
            (plus a 'std' column when ``std`` is true).
        y: Column to plot; rows are also sorted by it, descending.
        path_save: Directory the figure is written into.
        name_figure: File name of the saved figure.
        std: When true, the 'std' column is passed as ``xerr`` so error
            bars are drawn.
    """
    plot_options = {'x': 'Name', 'y': y}
    if std:
        plot_options['xerr'] = 'std'

    ordered = df.sort_values(by=y, ascending=False)
    axes = ordered.plot.barh(**plot_options)

    destination = os.path.join(path_save, name_figure)
    axes.get_figure().savefig(destination, bbox_inches='tight')
|
|
|
|
|
def generate_all_plot(df, path_save, name_figure):
    """Save one horizontal bar chart of the frame, indexed by 'Name'.

    Args:
        df: DataFrame with 'Name' and 'Accuracy sources' columns.
        path_save: Directory the figure is written into.
        name_figure: File name of the saved figure.

    Rows are sorted by 'Accuracy sources' (descending) before plotting.
    """
    ranked = df.sort_values(by='Accuracy sources', ascending=False).set_index('Name')
    figure = ranked.plot.barh().get_figure()
    figure.savefig(os.path.join(path_save, name_figure), bbox_inches='tight')
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Input: either a single CSV path (str) or a list of CSV paths. A list
    # selects the mean/std plots aggregated over all listed files.
    path_file = "results_QA_gen_V5.csv"

    # Output directory for the generated PNG figures.
    path_save = "plots/QA_generated_V5"

    # exist_ok=True replaces the exists()-then-makedirs() pair, which could
    # fail if the directory appeared between the check and the creation.
    os.makedirs(path_save, exist_ok=True)

    if isinstance(path_file, list):
        create_plots_list_values(path_file, path_save)
    else:
        create_plots(path_file, path_save)