from main import get_env, set_log_level, load_training_set
import model

import pymor.basic as pmb

from tqdm import tqdm
import argparse


def make_parser():
    """Build the command-line parser of this script.

    Returns:
        argparse.ArgumentParser: parser accepting
        ``-o/--output_file`` (str, default ``'u.out'``),
        ``-r/--rom_dir`` (str, default ``'.'``),
        ``-i/--input_param_file`` (str, required) and
        ``-f/--full_order_save`` (boolean flag).
    """
    parser = argparse.ArgumentParser(
        # Adjacent string literals are joined with no separator, so the space
        # after the comma must be written explicitly.
        description='loads a reduced order model, '
                    'solves it for the specified parameters and writes the results to the output file '
                    '(with -f, also the full-order solutions to a hdf5 file)',
        prefix_chars='-')
    parser.add_argument('-o', '--output_file',
                        help='output file name',
                        type=str,
                        default='u.out')
    parser.add_argument('-r', '--rom_dir',
                        help='set the directory that contains the rom. '
                             'It must contain rom.out (and pod_basis.out when -f is used)',
                        type=str,
                        default='.')
    parser.add_argument('-i', '--input_param_file',
                        help='set input parameter file',
                        type=str,
                        required=True)
    parser.add_argument('-f', '--full_order_save',
                        help='also save the full-order results',
                        action='store_true')

    return parser


# Script executed on a Qarnot compute node.
# It loads a reduced order model (ROM), solves it for the parameters of the
# training set assigned to this instance and pickles the reduced solutions to
# the output file. It is designed to run on several Qarnot instances at once.
# Options:
#   -i   input parameter file (required).
#   -o   output file name. Defaults to u.out.
#   -r   directory that contains the ROM. It must contain rom.out
#        (and pod_basis.out when -f is used). Defaults to the current directory.
#   -f   also reconstruct the full-order solutions and save them to
#        <output file name without extension>.h5.
if __name__ == '__main__':
    import os

    # Instance id / instance count of this node, as returned by main.get_env.
    iid, icount = get_env()
    set_log_level(iid)

    args = make_parser().parse_args()

    param_space = model.make_param_space()

    # NOTE(review): pmb.load unpickles the file; only load ROMs produced by a trusted source.
    with open(os.path.join(args.rom_dir, 'rom.out'), 'rb') as rom_file:
        rom = pmb.load(rom_file)

    # The instance id/count are forwarded, presumably so that each instance only
    # receives its own share of the parameters — see main.load_training_set.
    training_set = load_training_set(args, param_space, iid, icount)

    U = rom.solution_space.empty()
    # Index applied to every ROM solution to keep only the part of interest
    # (semantics defined by model.used_time).
    t_used = model.used_time()

    # Only instance 0 displays task advancement through tqdm, for monitoring;
    # the other instances iterate silently.
    for mu in tqdm(training_set, disable=iid != 0):
        U.append(rom.solve(mu)[t_used])

    # Close the output file deterministically instead of relying on GC to flush it.
    with open(args.output_file, 'wb') as out_file:
        pmb.dump(U, out_file)

    if args.full_order_save:
        from main import make_fom
        # NOTE(review): mesh path is relative to the current working directory.
        fom = make_fom('mesh.xml')
        rb_basis = model.load_sol_list(os.path.join(args.rom_dir, 'pod_basis.out'),
                                       fom.solution_space, radica='basis')
        reductor = pmb.InstationaryRBReductor(fom, RB=rb_basis, product=fom.h1_0_semi_product)
        # Lift every reduced solution back to the full-order space.
        U_rec = fom.solution_space.empty()
        for u in U:
            U_rec.append(reductor.reconstruct(u))
        # Strip only the final extension: split('.')[0] cut at the first dot
        # anywhere in the path (e.g. '../out/u.out' -> '.h5').
        h5_file = f'{os.path.splitext(args.output_file)[0]}.h5'
        model.dump_sol_list(h5_file, U_rec, 'u')