import os
import unittest

from loss_by_birth_year import load_data, compute_loss_since_year, compute_loss_over_range, simulate_loss_by_birth_year, age_of_work

class TestLossByBirthYear(unittest.TestCase):
    """Tests for the helpers exported by ``loss_by_birth_year``.

    Every test runs against the CSV snapshot stored next to this module, so
    the expected numbers below are tied to that exact data file. The values
    under test are floats computed from parsed CSV data, so they are compared
    with a tolerance (``assertAlmostEqual``) rather than exact equality.
    """

    # CSV snapshot loaded by setUp(); it must live in this module's directory.
    _DATA_FILENAME = "snb-data-plkopr-en-selection-20210421_0900.csv"

    def setUp(self) -> None:
        """Load the CSV snapshot into ``self.data`` via ``load_data``."""
        super().setUp()
        # Resolve the path from this module's own directory so the fixture is
        # found regardless of the process's working directory.
        data_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), self._DATA_FILENAME
        )
        # An explicit encoding keeps parsing independent of the platform locale.
        with open(data_path, encoding="utf-8") as f:
            self.data = load_data(f.read())

    def test_load_worked(self):
        """load_data returns a non-empty list."""
        self.assertIsInstance(self.data, list)
        self.assertTrue(self.data)

    def test_compute_loss_since_year_zero(self):
        """The loss from a year to itself is zero."""
        have = compute_loss_since_year(self.data, 1922, 1922)
        want = 0.0
        self.assertAlmostEqual(want, have)

    def test_compute_loss_since_year_valid(self):
        """The loss from 1922 to 2021 is (value_2021 - value_1922) / value_1922."""
        have = compute_loss_since_year(self.data, 1922, 2021)
        # NOTE(review): the literals are presumably the CSV values for 2021 and
        # 1922; update them if the data snapshot changes.
        want = (100.0553 - 17.791231) / 17.791231
        # Tolerance-based: `have` comes from float arithmetic on parsed CSV data.
        self.assertAlmostEqual(want, have)

    def test_compute_loss_over_range_trivial(self):
        """A one-year range yields a single compute_loss_since_year result."""
        have = compute_loss_over_range(self.data, 1923, 1924)
        want = [compute_loss_since_year(self.data, 1923, 1924)]
        # Both sides come from the same computation, so exact equality is valid.
        self.assertEqual(want, have)

    def test_compute_loss_over_range_series(self):
        """Each year in [start, stop) gets its loss measured up to `stop`."""
        start, stop = 1923, 1926
        have = compute_loss_over_range(self.data, start, stop)
        want = [
            (16.10145965 - 15.62565066) / 15.62565066,  # 23 -> 26
            (16.10145965 - 16.41218801) / 16.41218801,  # 24 -> 26
            (16.10145965 - 16.48030236) / 16.48030236,  # 25 -> 26
        ]
        self.assertEqual(len(want), len(have))
        # Compare element-wise with a tolerance; subTest reports the failing year.
        for year, want_loss, have_loss in zip(range(start, stop), want, have):
            with self.subTest(year=year):
                self.assertAlmostEqual(want_loss, have_loss)

    def test_simulate_loss_by_birth_year(self):
        """simulate_loss_by_birth_year matches a reference built from compute_loss_over_range."""
        birth_year = 1980
        end_year = 2021
        avg_monthly_saving = 1289
        start_work = birth_year + age_of_work
        # Pins the module constant: this test assumes age_of_work == 21.
        self.assertEqual(start_work, 2001)

        have_inflation, have_total = simulate_loss_by_birth_year(
            self.data,
            birth_year,
            simulation_end_year=end_year,
            monthly_average_savings=avg_monthly_saving,
        )

        # Reference inflation: mean of the per-year losses from start_work to end_year.
        want_inflation_series = compute_loss_over_range(self.data, start_work, end_year)
        want_inflation = sum(want_inflation_series) / len(want_inflation_series)
        self.assertAlmostEqual(want_inflation, have_inflation)

        # Reference total: each year's savings (12 * monthly) divided by
        # (1 + loss), multiplied by that year's loss, then summed.
        yearly_saving = 12 * avg_monthly_saving
        want_total = sum(
            (yearly_saving / (1 + loss)) * loss for loss in want_inflation_series
        )
        self.assertAlmostEqual(want_total, have_total)

if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()