from model import make_fom, make_param_space
from rombuild import load_remote
import model

import pymor.basic as pmb
import numpy as np

from matplotlib import pyplot as plt

import argparse
from tqdm import tqdm


def rel_error(fom_sol, rom_sol, product):
    '''Return the relative error of rom_sol with respect to fom_sol.

    The norm of the difference (rom_sol - fom_sol) is divided by the norm of
    fom_sol; both norms are evaluated with the given product.
    No guard is applied against a zero norm of fom_sol.
    '''
    deviation_norm = (rom_sol - fom_sol).norm(product)
    reference_norm = fom_sol.norm(product)
    return deviation_norm / reference_norm


def compute_proj_errors_orth_basis(basis, V, product):
    '''Return the worst projection error of V for growing prefixes of basis.

    For N = 5, 10, ... (up to len(basis)), the vectors of V are expanded on the
    first N basis vectors using their inner products (w.r.t. product) as
    coefficients, and the largest norm of the remainder over all vectors of V
    is recorded. Using inner products as coefficients is an orthogonal
    projection only if the basis is orthonormal with respect to product.

    The result is a list with one entry per value of N, in increasing order.
    '''
    def _worst_error(size):
        # Restrict to the first `size` basis vectors, expand V on them and
        # measure the largest remainder.
        sub_basis = basis[:size]
        coefficients = V.inner(sub_basis, product=product)
        remainder = V - sub_basis.lincomb(coefficients)
        return np.max(remainder.norm(product))

    # tqdm only adds a progress bar around the iteration over basis sizes.
    return [_worst_error(size) for size in tqdm(range(5, len(basis) + 1, 5))]


def plot_proj_error(basis, error, filename):
    '''Draw the projection errors against the basis size and save the figure.

    The x values are the basis sizes 5, 10, ... up to len(basis); error must
    provide one value per such size. The y axis is logarithmic and the figure
    is written to filename.
    '''
    basis_sizes = list(range(5, len(basis) + 1, 5))

    figure = plt.figure(figsize=(12, 8))
    axes = figure.gca()
    axes.semilogy(basis_sizes, error, label='POD')
    axes.set_title('error according to the size of the basis')
    axes.legend()
    figure.savefig(filename)


def plot_rel_error(error, filename):
    '''Scatter-plot the relative errors, expressed in percent, and save the figure.

    error is a sequence (list, tuple or numpy array) of relative errors given
    as fractions; entry i is drawn at x = i. The figure is written to filename.
    '''
    # Coerce to an array before scaling: multiplying a plain list by 100 would
    # repeat it 100 times instead of converting the values to percent.
    error_percent = np.asarray(error) * 100
    plt.figure(figsize=(12, 8))
    plt.scatter(x=np.arange(len(error_percent)), y=error_percent)
    plt.title('relative error for each parameter value in %')
    plt.savefig(filename)


def make_parser():
    '''Build the command-line parser of the fom/rom comparison script.

    Declared options (all optional):
      -r / --rom_dir        directory containing the rom (default '.')
      -n / --nb_files       number of result files, as int (default 30)
      -i / --input_fom_dir  directory of the fom solutions (default 'train')
      -u / --input_rom_sol  file holding the rom solutions (default 'u.out')
    '''
    # One (flags, keyword arguments) entry per option, in declaration order.
    option_specs = (
        (('-r', '--rom_dir'),
         dict(help='set the directory that contains the rom. It must contain rom.out and basis.h5',
              type=str,
              default='.')),
        (('-n', '--nb_files'),
         dict(help='number of result files',
              type=int,
              default=30)),
        (('-i', '--input_fom_dir'),
         dict(help='input directory for the fom solutions. the files should be dir/u0_c.h5, ..., dir/u{n}_c.h5',
              type=str,
              default='train')),
        (('-u', '--input_rom_sol'),
         dict(help='input file for the rom solutions',
              type=str,
              default='u.out')),
    )

    cli = argparse.ArgumentParser(prefix_chars='-',
                                  description='Script that will compare the result of the fom and rom')
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli

# Script that compares the results of the fom and the rom.
# It loads previously computed solutions and looks for the worst error.
# It also plots the evolution of the error as a function of the number of elements in the reduction basis.
# Options are:
#   -r   set the directory that contains the rom. It must contain rom.out and basis.h5
#   -n   number of result files (defaults to 30)
#   -i   input directory for the fom solutions. The files should be dir/u0_c.h5, ..., dir/u{n}_c.h5 (defaults to 'train').
#   -u   input file for the rom solutions (defaults to u.out)
if __name__=='__main__':
    # Entry point: compare stored fom solutions with reconstructed rom
    # solutions, print the worst relative error and save two plots.
    args = make_parser().parse_args()

    fom = make_fom('mesh.xml')
    param_space = make_param_space()

    rom_dir = args.rom_dir
    # NOTE(review): pmb.load is presumably pickle-based, so only trusted files
    # should be passed to this script. The context manager makes sure the file
    # handle is closed once the object has been read.
    with open(f'{rom_dir}/rom.out', 'rb') as rom_fh:
        rom = pmb.load(rom_fh)
    rb_basis = model.load_sol_list(f'{rom_dir}/pod_basis.out', fom.solution_space, radica='basis')
    # The reductor is only used to map reduced coordinates back to the fom space.
    reductor = pmb.InstationaryRBReductor(fom, RB=rb_basis, product=fom.h1_0_semi_product)

    U_fom = load_remote(fom, args.nb_files, args.input_fom_dir)
    print(f'\nloaded {len(U_fom)} vectors of dim {U_fom.dim} from fom')

    rom_file = args.input_rom_sol
    with open(rom_file, 'rb') as sol_fh:
        u_rom = pmb.load(sol_fh)
    print(f'loaded {len(u_rom)} vectors of dim {u_rom.dim} from rom')
    U_rom = reductor.reconstruct(u_rom)
    print(f'reconstructed {len(U_rom)} vectors of dim {U_rom.dim} from rom')

    # Both error measures use the same product (fom.h1_0_semi_product).
    pod_errors = compute_proj_errors_orth_basis(rb_basis, U_fom, fom.h1_0_semi_product)
    rel_errors = rel_error(U_fom, U_rom, fom.h1_0_semi_product)

    print('\n')
    print('***********************************')
    print(f'***  max rel error : {np.max(rel_errors)*100:.6f}%  ***')
    print('***********************************')
    print('see output bucket for plot of:\n' +
          '\t· error with respect to reduction basis size\n'+
          '\t· relative error for each parameter value')

    plot_proj_error(rb_basis, pod_errors, 'error_wrt_basis_size.png')
    plot_rel_error(rel_errors, 'rel_error_wrt_parameter.png')
