import os
import shutil

import caffe
import lmdb
import numpy as np
import sklearn.preprocessing as pre
from sklearn.datasets import load_iris

import iris_tuto as it


def create_iris_lmdb(lmdb_path):
    """Write the Iris data set into a fresh LMDB database at *lmdb_path*.

    An existing directory at *lmdb_path* is deleted first. Every sample is
    stored as a serialized Caffe ``Datum`` holding a (1, 1, 4) float array
    and its integer class label, under a zero-padded 8-digit key.

    Args:
        lmdb_path: Directory in which the LMDB environment is created.
    """
    print("Creating LMDB database")
    if os.path.isdir(lmdb_path):
        print("Removing old database")
        shutil.rmtree(lmdb_path)

    data = load_iris()
    # One (channels, height, width) = (1, 1, 4) array per sample, which is
    # the 3-D shape array_to_datum is given below.
    inputs = data['data'].reshape((-1, 1, 1, 4))
    targets = data['target']

    db = lmdb.open(lmdb_path, map_size=2 ** 20)
    try:
        with db.begin(write=True) as txn:
            for count, (sample, label) in enumerate(zip(inputs, targets)):
                # Datum.label is a single int32 field, so it receives the
                # integer class index; a one-hot array cannot be assigned
                # to it.
                datum = caffe.io.array_to_datum(sample, label=int(label))
                # LMDB keys must be bytes (a str key raises on Python 3).
                key = '{:08}'.format(count).encode('ascii')
                txn.put(key, datum.SerializeToString())
    finally:
        # Release the environment even if writing fails part-way.
        db.close()
    print("LMDB database created successfully.")


def main():
    """Build the Iris LMDB, train the network, then inspect and evaluate it.

    The steps are: create the LMDB database, run the solver, print the
    prediction for one example, print the network definitions and weights,
    and finally compute accuracy over the whole Iris data set.
    """
    lmdb_path = "lmdb"
    solver_prototxt_filename = 'iris_solver_lmdb.prototxt'
    train_test_prototxt_filename = 'iris_train_test_lmdb.prototxt'
    deploy_prototxt_filename = 'iris_deploy.prototxt'
    caffemodel_filename = 'snapshots/iris_lmdb_iter_1000000.caffemodel'

    create_iris_lmdb(lmdb_path)
    it.train(solver_prototxt_filename)

    # Get predicted outputs
    input_eg = np.array([[5.1, 3.5, 1.4, 0.2]])
    print(it.get_predicted_output(deploy_prototxt_filename, caffemodel_filename, input_eg))

    # Print network
    it.print_network(deploy_prototxt_filename)
    it.print_network(train_test_prototxt_filename)
    it.print_network_weights(train_test_prototxt_filename, caffemodel_filename)

    # Compute performance metrics
    # `data` was previously referenced without ever being defined in this
    # function; build it here from the Iris data set.
    # NOTE(review): the label format expected by it.get_accuracy is not
    # visible in this file. One-hot rows are assumed here - confirm against
    # iris_tuto.
    iris = load_iris()
    targets = iris['target']
    num_classes = int(targets.max()) + 1
    data = {
        'input': iris['data'].reshape((-1, 1, 1, 4)),
        'output': np.eye(num_classes)[targets],
    }
    inputs = data['input']
    outputs = it.get_predicted_outputs(deploy_prototxt_filename, caffemodel_filename, inputs)
    it.get_accuracy(data['output'], outputs)


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
