Agent_UB / plot_results.py
t-pris's picture
Upload folder using huggingface_hub
7b295db verified
import pandas as pd
import os
# import matplotlib
def create_plots(path_file, path_save):
    """Read one results CSV and save its three bar charts into ``path_save``.

    Produces ``figure_replies.png`` (column ``Accuracy replies``),
    ``figure_sources.png`` (column ``Accuracy sources``) and
    ``figure_total.png`` (all columns side by side, via ``generate_all_plot``).

    Args:
        path_file: Path to a CSV with a ``Name`` column plus the
            ``Accuracy replies`` and ``Accuracy sources`` columns.
        path_save: Directory that receives the PNG files (the ``__main__``
            block creates it beforehand).
    """
    results = pd.read_csv(path_file)
    single_column_figures = (
        ('Accuracy replies', 'figure_replies.png'),
        ('Accuracy sources', 'figure_sources.png'),
    )
    for column, figure_name in single_column_figures:
        generate_plot(results, column, path_save, figure_name)
    generate_all_plot(results, path_save, 'figure_total.png')
def create_plots_list_values(list_path_file, path_save):
    """Aggregate several results CSVs and plot per-Name mean accuracy with std bars.

    Rows from all files are grouped by ``Name`` (see ``get_mean_std``) and three
    figures are written to ``path_save``: ``figure_replies_mean.png``,
    ``figure_sources_mean.png`` and ``figure_total_mean.png`` (both means side
    by side, sorted by mean reply accuracy). The merged summary table is also
    printed to stdout.

    Args:
        list_path_file: Iterable of CSV paths, each with ``Name``,
            ``Accuracy replies`` and ``Accuracy sources`` columns.
        path_save: Directory that receives the PNG files.
    """
    frames = [pd.read_csv(csv_path) for csv_path in list_path_file]
    replies_stats, sources_stats = get_mean_std(frames)

    generate_plot(replies_stats, 'mean', path_save, 'figure_replies_mean.png', std=True)
    generate_plot(sources_stats, 'mean', path_save, 'figure_sources_mean.png', std=True)

    # Align the two summaries on Name so every row carries both metrics.
    summary = pd.concat(
        [replies_stats.set_index('Name'), sources_stats.set_index('Name')],
        axis=1,
    )
    # Both inputs contribute 'mean'/'std' labels; give the four columns distinct names.
    summary.columns = ['mean_replies', 'std_r', 'mean_sources', 'std_s']
    summary = summary.sort_values(by='mean_replies', ascending=False)
    print(summary)

    # One row of error values per plotted series, in the same order as ``y``.
    error_bars = summary[['std_r', 'std_s']].to_numpy().T
    ax = summary.plot.barh(y=['mean_replies', 'mean_sources'], xerr=error_bars)
    ax.get_figure().savefig(
        os.path.join(path_save, 'figure_total_mean.png'), bbox_inches='tight'
    )
def get_mean_std(df_list):
    """Stack result frames and summarise both accuracy columns per ``Name``.

    Args:
        df_list: Sequence of DataFrames with ``Name``, ``Accuracy replies`` and
            ``Accuracy sources`` columns.

    Returns:
        ``(replies, sources)``: two DataFrames with ``Name``, ``mean`` and
        ``std`` columns, as built by ``get_mean_std_from_merge``.

    Raises:
        ValueError: from ``pd.concat`` when ``df_list`` is empty.
    """
    stacked = pd.concat(df_list)
    replies, sources = (
        get_mean_std_from_merge(stacked, column)
        for column in ('Accuracy replies', 'Accuracy sources')
    )
    return replies, sources
# ax = output_sources.plot.barh(x='Name', y='mean', xerr='std')
# fig = ax.get_figure()
# fig.savefig(os.path.join("plots/QA_generated+Nicolas", "test2.jpg"), bbox_inches='tight')
def get_mean_std_from_merge(df_merge, accuracy_name):
    """Return the per-``Name`` mean and sample standard deviation of one column.

    Args:
        df_merge: DataFrame with a ``Name`` column and the ``accuracy_name`` column.
        accuracy_name: Column to summarise.

    Returns:
        DataFrame with columns ``Name``, ``mean`` and ``std``, one row per
        distinct name (in groupby's sorted order). ``std`` is NaN for a name
        that occurs only once.
    """
    # Selecting the column before aggregating yields flat 'mean'/'std' labels,
    # so no MultiIndex flattening is needed; reset_index restores 'Name'.
    stats = df_merge.groupby('Name')[accuracy_name].agg(['mean', 'std'])
    return stats.reset_index()
def generate_plot(df, y, path_save, name_figure, std = False):
    """Save a horizontal bar chart of column ``y`` labelled by ``Name``.

    Rows are sorted by ``y`` in descending order before plotting; the input
    frame itself is not modified. When ``std`` is true, the frame's ``std``
    column supplies the x error bars.

    Args:
        df: DataFrame with ``Name`` and ``y`` columns (plus ``std`` if ``std``
            is true).
        y: Column holding the bar lengths.
        path_save: Directory that receives the image.
        name_figure: File name of the image inside ``path_save``.
        std: Whether to draw error bars from the ``std`` column.
    """
    ordered = df.sort_values(by=y, ascending=False)
    plot_kwargs = {'x': 'Name', 'y': y}
    if std:
        plot_kwargs['xerr'] = 'std'
    ax = ordered.plot.barh(**plot_kwargs)
    # NOTE(review): the figure is never closed; repeated calls may accumulate
    # open matplotlib figures — consider matplotlib.pyplot.close(fig).
    fig = ax.get_figure()
    fig.savefig(os.path.join(path_save, name_figure), bbox_inches='tight')
def generate_all_plot(df, path_save, name_figure):
    """Save one grouped horizontal bar chart of all columns, labelled by ``Name``.

    Rows are ordered by ``Accuracy sources`` (highest first) and ``Name``
    becomes the index, so the remaining columns (the ones pandas can plot)
    are drawn side by side for each name.

    Args:
        df: DataFrame with ``Name`` and ``Accuracy sources`` columns; it is not
            modified.
        path_save: Directory that receives the image.
        name_figure: File name of the image inside ``path_save``.
    """
    by_name = df.sort_values(by='Accuracy sources', ascending=False).set_index('Name')
    figure = by_name.plot.barh().get_figure()
    figure.savefig(os.path.join(path_save, name_figure), bbox_inches='tight')
if __name__ == "__main__":
    # Alternative inputs (commented out). path_file may be a single CSV or a
    # list of CSVs; a list triggers mean/std aggregation across the files.
    # path_file = "results_QA_Nicolas.csv"
    # path_save = "plots/QA_Nicolas"
    # path_file = "results_QA_generated.csv"
    # path_save = "plots/QA_generated"
    # path_file = ["results_QA_gen_V5.csv", "results_QA_Nicolas_V5.csv"]
    # path_save = "plots/QA_generated+Nicolas_V5"
    path_file = "results_QA_gen_V5.csv"
    path_save = "plots/QA_generated_V5"

    # exist_ok=True avoids the check-then-create race of
    # os.path.exists() followed by os.makedirs().
    os.makedirs(path_save, exist_ok=True)

    # A single path gives per-run plots; a list/tuple of paths gives
    # mean/std plots aggregated across the runs.
    if isinstance(path_file, (list, tuple)):
        create_plots_list_values(path_file, path_save)
    else:
        create_plots(path_file, path_save)