import os
import pandas as pd
import matplotlib.pyplot as plt

# Set up seaborn, our plotting library
import seaborn as sns # Import seaborn
# Module-level side effect: importing this file restyles every matplotlib
# figure drawn afterwards (seaborn "white" axes style plus set_theme's other
# defaults). palette=None skips the palette step, leaving the current color
# cycle as-is rather than applying seaborn's default palette.
sns.set_theme(style="white", palette=None) # Set plotting theme

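# Colormap helpers. NOTE: `matplotlib.cm.get_cmap` (imported below) was removed in
# Matplotlib 3.9; on newer versions use `matplotlib.colormaps[name]` instead.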
from matplotlib.colors import Normalize
from matplotlib.cm import get_cmap

def get_mouse_data(path='data/Mouse_Data_extended.xlsx', *, raw=None):
    """Load the mouse study spreadsheet and reshape it to long format.

    The spreadsheet is wide: each row carries eight weekly weight columns
    (``Weight_wk1`` .. ``Weight_wk8``) and eight weekly blood-glucose columns
    (``BGL_wk1`` .. ``BGL_wk8``). The result has one row per spreadsheet row
    per week, with that week's weight and BGL side by side.

    Args:
        path: Excel file to read. Defaults to the project's data file.
        raw: Optional already-loaded wide-format DataFrame. When given,
            ``path`` is ignored and nothing is read from disk.

    Returns:
        DataFrame with columns ``ID, Sex, Genotype, Ozempic_Dose_mg, Diet,
        Age_Weeks, Final_Weight_g, Liver_Weight_mg, Final_BGL_mg_dL,
        Liver_Pathology_Note, Week, Weight_g, BGL_mg_dL``, where ``Week`` is
        an integer 1-8.
    """
    df = pd.read_excel(path) if raw is None else raw

    n_weeks = 8
    weight_columns = [f'Weight_wk{i}' for i in range(1, n_weeks + 1)]
    bgl_columns = [f'BGL_wk{i}' for i in range(1, n_weeks + 1)]

    # Per-row attributes that are repeated on every weekly row.
    id_vars = ['ID', 'Sex', 'Genotype', 'Ozempic_Dose_(mg)', 'Diet', 'Age_(Weeks)',
               'Final_Weight_(g)', 'Liver_Weight_(mg)', 'Final_BGL_(mg/dL)',
               'Liver_Pathology_Note']

    df_weight_melted = df.melt(id_vars=id_vars, value_vars=weight_columns,
                               var_name='Week', value_name='Weight_(g)')
    df_bgl_melted = df.melt(id_vars=id_vars, value_vars=bgl_columns,
                            var_name='Week', value_name='BGL_(mg/dL)')

    # 'Weight_wk3' / 'BGL_wk3' -> 3. The pattern must be a raw string: '\d' in a
    # plain literal is an invalid escape (SyntaxWarning on Python 3.12+).
    for melted in (df_weight_melted, df_bgl_melted):
        melted['Week'] = melted['Week'].str.extract(r'(\d+)', expand=False).astype(int)

    # Pair each week's weight with the same row's BGL for that week.
    df_melted = pd.merge(df_weight_melted, df_bgl_melted, on=id_vars + ['Week'])

    # Rename by name, not by position, so a change in merge column order can
    # never silently mislabel the data.
    return df_melted.rename(columns={
        'Ozempic_Dose_(mg)': 'Ozempic_Dose_mg',
        'Age_(Weeks)': 'Age_Weeks',
        'Final_Weight_(g)': 'Final_Weight_g',
        'Liver_Weight_(mg)': 'Liver_Weight_mg',
        'Final_BGL_(mg/dL)': 'Final_BGL_mg_dL',
        'Weight_(g)': 'Weight_g',
        'BGL_(mg/dL)': 'BGL_mg_dL',
    })

def show_calculator_solution(path='solutions/calculator_solution.txt'):
    """Print the reference solution for the calculator exercise.

    Each line of the file is echoed once. The line's own trailing newline is
    stripped before printing, because ``print`` adds one. Without this, every
    line would be followed by a spurious blank line.

    Args:
        path: Text file holding the solution. Defaults to the project's copy.
    """
    with open(path, encoding='utf-8') as f:
        # Stream line by line; no need to hold the whole file in memory.
        for line in f:
            print(line.rstrip('\n'))

def make_kde_plot(dataframe, diet, sex, colormap):
    """Draw a grid of 2-D KDE plots (BGL vs. weight) for one diet/sex group.

    Facets are laid out with columns by ``Ozempic_Dose_mg`` and rows by
    ``Genotype``. Each ``Week`` is drawn as a separate hue.

    Args:
        dataframe: Long-format data, e.g. as returned by ``get_mouse_data``.
            Needs ``Diet``, ``Sex``, ``Genotype``, ``Ozempic_Dose_mg``,
            ``Week``, ``BGL_mg_dL`` and ``Weight_g`` columns.
        diet: Value of ``Diet`` to keep.
        sex: Value of ``Sex`` to keep.
        colormap: Palette passed unchanged to ``sns.FacetGrid(palette=...)``,
            e.g. a palette/colormap name or a list of colors.

    Raises:
        ValueError: If no rows match both ``diet`` and ``sex``.
    """
    # Boolean mask instead of DataFrame.query with interpolated strings.
    # Values containing quote characters broke the query expression.
    mask = (dataframe['Diet'] == diet) & (dataframe['Sex'] == sex)
    data = dataframe[mask]
    if data.empty:
        raise ValueError(f'No rows with Diet == {diet!r} and Sex == {sex!r}')

    # Pass the palette through as-is. Stringifying it broke list/dict palettes.
    g = sns.FacetGrid(data=data, col='Ozempic_Dose_mg', row='Genotype',
                      hue='Week', palette=colormap, margin_titles=True)
    # FacetGrid.map sets each curve's label from its hue level (the week
    # number), so no explicit label is passed here.
    g.map(sns.kdeplot, 'BGL_mg_dL', 'Weight_g')
    plt.show()