# This file is part of Awali.
# Copyright 2016-2019 Sylvain Lombardy, Victor Marsault, Jacques Sakarovitch
#
# Awali is a free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import unittest, sys
import semiring
import awalipy as vr


# Automata shared by the ordered test methods of AutomatonEditionTests.
# A is built and edited through the object-oriented API (A.add_state(), ...),
# B through the module-level functional API (vr.add_state(B), ...); C and D
# are copies of A and B made in test_07_copy.  A and B are re-created by
# test_01_creation each time the suite is run for a new weightset.
# (A `global` statement at module level is a no-op, so none is needed here.)
A = None
B = None
C = None
D = None

class AutomatonEditionTests(unittest.TestCase):
    """Edition tests for awalipy automata over the current weightset.

    Each scenario is played twice: on the global ``A`` through the
    object-oriented API (``A.add_state()``) and on the global ``B`` through
    the module-level functional API (``vr.add_state(B)``); the two automata
    are then compared with ``assertAutomatonSynctacticEquality``.

    The tests build on one another through the module-level automata
    A, B, C and D, so they rely on unittest running the methods in
    alphabetical order (hence the ``test_NN_`` prefixes).

    ``semiring.values`` holds sample weights of the weightset under test:
    ``values[0]`` is its zero and ``values[1]`` its one (see test_03).
    ``semiring.added`` is the weight expected after adding ``values[4]`` and
    then ``values[3]`` to the same initial state or transition; over F2 that
    sum is zero, which is special-cased below.
    """

    # NOTE: disabled attempt at aborting the remaining tests after the first
    # failure (later tests depend on the state left by earlier ones).
    #def run(self, result=None):
    #    if result.failures or result.errors:
    #        print "Abort remaining tests"
    #    else:
    #        super(unittest.TestCase, self).run(result)

    def assertAutomatonSynctacticEquality(self, X, Y):
        """Assert that automata X and Y are syntactically identical.

        Compares the state lists, the initial and final states together with
        their weights, and the transitions (ids, sources, destinations,
        labels and weights).
        """
        self.assertEqual(X.states(), Y.states())
        self.assertEqual(sorted(X.initial_states()), sorted(Y.initial_states()))
        self.assertEqual(sorted(X.final_states()), sorted(Y.final_states()))
        for s in Y.initial_states():
            self.assertEqual(X.get_initial_weight(s), Y.get_initial_weight(s))
        for s in Y.final_states():
            self.assertEqual(X.get_final_weight(s), Y.get_final_weight(s))
        self.assertEqual(X.transitions(), Y.transitions())
        for i in Y.transitions():
            self.assertEqual(X.src_of(i), Y.src_of(i))
            self.assertEqual(X.dst_of(i), Y.dst_of(i))
            self.assertEqual(X.label_of(i), Y.label_of(i))
            self.assertEqual(X.weight_of(i), Y.weight_of(i))

    def test_01_creation(self):
        """Create the shared automata A and B over alphabet 'ab'."""
        global A, B
        A = vr.Automaton('ab', semiring.name)
        B = vr.make_automaton('ab', semiring.name)
        #self.assertEqual(A.get_static_context(), 'lal_char_'+semiring.name.lower())
        #self.assertEqual(vr.get_static_context(B), 'lal_char_'+semiring.name.lower())

    def test_02_alphabet(self):
        """Check the alphabet and letter membership of A and B."""
        self.assertEqual(A.alphabet(), 'ab')
        self.assertEqual(vr.alphabet(B), 'ab')
        for c in ['a', 'b']:
            self.assertTrue(A.has_letter(c))
            self.assertTrue(vr.has_letter(B, c))

    def test_03_weightset(self):
        """Check that the zero and one of the weightset match values[0/1]."""
        self.assertEqual(A.weight_one(), semiring.values[1])
        self.assertEqual(A.weight_zero(), semiring.values[0])
        #self.assertEqual(A.get_weightset().lower(), semiring.name.lower())
        #self.assertEqual(vr.get_weightset(B).lower(), semiring.name.lower())
        self.assertEqual(vr.weight_one(B), semiring.values[1])
        self.assertEqual(vr.weight_zero(B), semiring.values[0])

    def test_04_states(self):
        """Add three states and check that exactly ids 0, 1 and 2 exist."""
        for _ in range(3):
            A.add_state()
        self.assertEqual(A.states(), [0, 1, 2])
        self.assertEqual(A.num_states(), 3)
        for i in range(3):
            self.assertTrue(A.has_state(i))
        for i in range(3, 10):
            self.assertFalse(A.has_state(i))
        for _ in range(3):
            vr.add_state(B)
        self.assertEqual(vr.states(B), [0, 1, 2])
        self.assertEqual(vr.num_states(B), 3)
        for i in range(3):
            self.assertTrue(vr.has_state(B, i))
        for i in range(3, 10):
            self.assertFalse(vr.has_state(B, i))

    def test_05_final_initial(self):
        """Edit initial/final states and check the resulting states and weights.

        Expected outcome: initial states 1 (weight one) and 2 (weight
        ``semiring.added``), final states 0 (weight one) and 1 (weight
        ``values[2]``).
        """
        A.set_initial(0)
        A.add_initial(2, semiring.values[4])
        A.add_initial(2, semiring.values[3])
        A.add_initial(1, semiring.values[1])
        A.unset_initial(0)
        A.add_initial(1, semiring.values[0])  # adding zero: weight unchanged
        A.set_final(0)
        A.add_final(1)
        A.set_final(1, semiring.values[2])  # overrides the weight from add_final
        A.unset_final(2)  # state 2 is not final: must be a no-op
        if semiring.name.lower() == 'f2':
            # values[4] + values[3] is zero in F2, which removed state 2 from
            # the initial states; make it initial again.
            self.assertFalse(A.is_initial(2))
            A.set_initial(2)
        self.assertEqual(A.initial_states(), [1, 2])
        self.assertEqual(A.final_states(), [0, 1])
        self.assertEqual(A.num_initials(), 2)
        self.assertEqual(A.num_finals(), 2)

        # Same edits through the functional API.
        vr.set_initial(B, 0)
        vr.add_initial(B, 2, semiring.values[4])
        vr.add_initial(B, 2, semiring.values[3])
        vr.add_initial(B, 1, semiring.values[1])
        vr.unset_initial(B, 0)
        vr.add_initial(B, 1, semiring.values[0])
        vr.set_final(B, 0)
        vr.add_final(B, 1)
        vr.set_final(B, 1, semiring.values[2])
        vr.unset_final(B, 2)
        if semiring.name.lower() == 'f2':
            self.assertFalse(vr.is_initial(B, 2))
            vr.set_initial(B, 2)
        self.assertEqual(vr.initial_states(B), [1, 2])
        self.assertEqual(vr.final_states(B), [0, 1])
        self.assertEqual(vr.num_initials(B), 2)
        self.assertEqual(vr.num_finals(B), 2)

        for i in [1, 2]:
            self.assertTrue(A.is_initial(i))
            self.assertTrue(vr.is_initial(B, i))
        self.assertEqual(A.get_initial_weight(1), semiring.values[1])
        self.assertEqual(vr.get_initial_weight(B, 1), semiring.values[1])
        self.assertEqual(A.get_initial_weight(2), semiring.added)
        self.assertEqual(vr.get_initial_weight(B, 2), semiring.added)
        self.assertFalse(A.is_initial(0))
        self.assertFalse(vr.is_initial(B, 0))
        # A non-initial state reports the zero weight.
        self.assertEqual(A.get_initial_weight(0), semiring.values[0])
        self.assertEqual(vr.get_initial_weight(B, 0), semiring.values[0])
        for i in [0, 1]:
            self.assertTrue(A.is_final(i))
            self.assertTrue(vr.is_final(B, i))
        self.assertEqual(A.get_final_weight(0), semiring.values[1])
        self.assertEqual(vr.get_final_weight(B, 0), semiring.values[1])
        self.assertEqual(A.get_final_weight(1), semiring.values[2])
        self.assertEqual(vr.get_final_weight(B, 1), semiring.values[2])
        self.assertFalse(A.is_final(2))
        self.assertFalse(vr.is_final(B, 2))
        # A non-final state reports the zero weight.
        self.assertEqual(A.get_final_weight(2), semiring.values[0])
        self.assertEqual(vr.get_final_weight(B, 2), semiring.values[0])

    def test_06_transitions(self):
        """Add transitions and check ids, endpoints, labels, weights and the
        predecessor/successor queries on both A and B."""
        L = [[1, 1, 'b'], [0, 1, 'a'], [0, 0, 'a'], [0, 0, 'b'], [0, 2, 'a'],
             [1, 1, 'b'], [1, 2, 'b'], [2, 2, 'a'], [0, 2, 'a']]
        # Entry k is added with weight values[k % 5]: entries 0 and 5 get the
        # zero weight and are therefore never created, and the last entry
        # repeats entry 4, so its weight is added to the existing transition
        # instead of creating a new one (values[4] + values[3]).
        for k, (src, dst, label) in enumerate(L):
            weight = semiring.values[k % 5]
            A.add_transition(src, dst, label, weight)
            vr.add_transition(B, src, dst, label, weight)
        if semiring.name.lower() == 'f2':
            # That sum is zero in F2, which deleted the transition: re-add it.
            self.assertFalse(A.has_transition(0, 2, 'a'))
            self.assertFalse(vr.has_transition(B, 0, 2, 'a'))
            A.add_transition(0, 2, 'a')
            vr.add_transition(B, 0, 2, 'a')
        # Drop the zero-weight entries so that L[j] describes the j-th
        # transition returned by transitions().
        L.pop(5)
        L.pop(0)

        for j, i in enumerate(A.transitions()):
            src, dst, label = L[j]
            self.assertTrue(A.has_transition(i))
            self.assertEqual(A.src_of(i), src)
            self.assertEqual(A.dst_of(i), dst)
            self.assertEqual(A.label_of(i), label)
            self.assertTrue(A.has_transition(src, dst, label))
            self.assertEqual(A.get_transition(src, dst, label), i)
            if j < 3:
                expected = semiring.values[j + 1]  # +1 since entry 0 was dropped
            elif j == 3:
                expected = semiring.added
            else:
                # Entries 0 and 5 were dropped: original index is j + 2.
                expected = semiring.values[j + 2 - 5]
            self.assertEqual(A.weight_of(i), expected)

        # Ids 10..19 are expected to be unused.  NOTE(review): test_08
        # deletes transition 9, which appears to be the highest id in use.
        for i in range(10, 20):
            self.assertFalse(A.has_transition(i))
        self.assertFalse(A.has_transition(2, 2, 'b'))
        with self.assertRaises(TypeError):
            A.has_transition(0, 1)
        self.assertEqual(A.predecessors(1), [0])
        self.assertEqual(A.predecessors(1, 'b'), [])
        self.assertEqual(sorted(A.predecessors(2, 'a')), [0, 2])
        self.assertEqual(sorted(A.successors(2)), [2])
        self.assertEqual(sorted(A.successors(0)), [0, 0, 1, 2])
        self.assertEqual(sorted(A.successors(0, 'a')), [0, 1, 2])

        for j, i in enumerate(vr.transitions(B)):
            src, dst, label = L[j]
            self.assertTrue(vr.has_transition(B, i))
            self.assertEqual(vr.src_of(B, i), src)
            self.assertEqual(vr.dst_of(B, i), dst)
            self.assertEqual(vr.label_of(B, i), label)
            self.assertTrue(vr.has_transition(B, src, dst, label))
            self.assertEqual(vr.get_transition(B, src, dst, label), i)
        for i in range(10, 20):
            self.assertFalse(vr.has_transition(B, i))
        self.assertFalse(vr.has_transition(B, 2, 2, 'b'))
        with self.assertRaises(TypeError):
            vr.has_transition(B, 0, 1)
        self.assertEqual(vr.predecessors(B, 1), [0])
        self.assertEqual(vr.predecessors(B, 1, 'b'), [])
        self.assertEqual(sorted(vr.predecessors(B, 2, 'a')), [0, 2])
        self.assertEqual(sorted(vr.successors(B, 2)), [2])
        self.assertEqual(sorted(vr.successors(B, 0)), [0, 0, 1, 2])
        self.assertEqual(sorted(vr.successors(B, 0, 'a')), [0, 1, 2])

        self.assertAutomatonSynctacticEquality(A, B)

    # NOTE: disabled, unfinished test (tr_id_1 is used without being
    # assigned in the non-F2 branch).
    #def test_07_eps_transitions(self):
        #A.add_state()
        #A.allow_eps_transition()
        #A.add_eps_transition(0, 3, semiring.values[2])
        #A.add_eps_transition(0, 3, semiring.values[3])
        #A.set_eps_transition(3, 1, semiring.values[2])
        #A.set_eps_transition(3, 1, semiring.values[4])
        #if (semiring.lower() != "f2"):

            #self.assertEqual(A.weight_of(tr_id_1), semiring.added)
        #else:
            #with self.assertRaises(TypeError):
                #tr_id_1= A.get_transition(0, 3, "")
        #tr_id_2= A.get_transition(3, 1, "")
        #self.assertEqual(A.weight_of(tr_id_2), semiring.values[4])

        #vr.add_state(B)
        #vr.add_eps_transition(B, 0, 3, semiring.values[2])
        #vr.add_eps_transition(B, 0, 3, semiring.values[3])
        #vr.set_eps_transition(B, 3, 1, semiring.values[2])
        #vr.set_eps_transition(B, 3, 1, semiring.values[4])

    def test_07_copy(self):
        """Copy A and B into C and D and check that all four are identical."""
        global C, D
        C = A.copy()
        self.assertAutomatonSynctacticEquality(A, C)
        D = vr.copy(B)
        self.assertAutomatonSynctacticEquality(B, D)
        self.assertAutomatonSynctacticEquality(C, D)

    def test_08_del(self):
        """Delete states and transitions until no transition is left.

        Deleting a missing state, or a transition that touches a missing
        state, raises ValueError; deleting a missing transition between
        existing states is a no-op.
        """
        A.del_state(1)
        self.assertEqual(A.states(), [0, 2])
        # Deleting a state removes every transition incident to it.
        for i in A.transitions():
            self.assertNotEqual(A.src_of(i), 1)
            self.assertNotEqual(A.dst_of(i), 1)
        A.del_transition(9)
        self.assertFalse(A.has_transition(9))
        A.del_transition(0, 2, 'a')
        self.assertFalse(A.has_transition(0, 2, 'a'))
        self.assertTrue(A.has_transition(0, 0, 'a'))
        self.assertTrue(A.has_transition(0, 0, 'b'))
        # Without a label, every transition between the two states goes.
        A.del_transition(0, 0)
        self.assertFalse(A.has_transition(0, 0, 'a'))
        self.assertFalse(A.has_transition(0, 0, 'b'))
        with self.assertRaises(ValueError):
            A.del_state(1)
        with self.assertRaises(ValueError):
            A.del_transition(1, 1, 'a')
        with self.assertRaises(ValueError):
            A.del_transition(2, 1, 'a')
        self.assertEqual(A.transitions(), [])
        A.del_transition(2, 2, 'b')
        self.assertEqual(A.transitions(), [])

        vr.del_state(B, 1)
        self.assertEqual(vr.states(B), [0, 2])
        for i in vr.transitions(B):
            self.assertNotEqual(vr.src_of(B, i), 1)
            self.assertNotEqual(vr.dst_of(B, i), 1)
        vr.del_transition(B, 9)
        self.assertFalse(vr.has_transition(B, 9))
        vr.del_transition(B, 0, 2, 'a')
        self.assertFalse(vr.has_transition(B, 0, 2, 'a'))
        self.assertTrue(vr.has_transition(B, 0, 0, 'a'))
        self.assertTrue(vr.has_transition(B, 0, 0, 'b'))
        vr.del_transition(B, 0, 0)
        self.assertFalse(vr.has_transition(B, 0, 0, 'a'))
        self.assertFalse(vr.has_transition(B, 0, 0, 'b'))
        with self.assertRaises(ValueError):
            vr.del_state(B, 1)
        with self.assertRaises(ValueError):
            vr.del_transition(B, 1, 1, 'a')
        with self.assertRaises(ValueError):
            vr.del_transition(B, 2, 1, 'a')
        self.assertEqual(vr.transitions(B), [])
        vr.del_transition(B, 2, 2, 'b')
        self.assertEqual(vr.transitions(B), [])
        self.assertAutomatonSynctacticEquality(A, B)


    # NOTE: disabled; would rerun every test with a lower-cased weightset name.
    #def test_99_rerun_all_with_lowercase_weightset(self):
        #global A, B
        #semiring.name = semiring.name.lower()
        #self.test_01_creation()
        #self.test_02_alphabet()
        #self.test_03_weightset()
        #self.test_04_states()
        #self.test_05_final_initial()
        #self.test_06_transitions()
        #self.test_07_copy()
        #self.test_08_del()











# Banner separating the per-weightset runs on stderr.
_SEPARATOR = "============================================================\n"

# Run the whole edition suite once per weightset: semiring.next() presumably
# loads the next weightset into semiring.name / values / added and returns a
# falsy value once every weightset has been visited.
# NOTE(review): the TestResult returned by run() is discarded, so the exit
# status of this script does not reflect test failures.
while semiring.next():
    # sys.stderr.write() behaves identically on Python 2 and 3, whereas the
    # Python-2-only `print >> sys.stderr, ...` statement raises TypeError on
    # Python 3.
    sys.stderr.write("\n")
    sys.stderr.write(_SEPARATOR)
    sys.stderr.write("\tEdition tests for automata over weightset " + semiring.name + ".\n")
    sys.stderr.write(_SEPARATOR)
    suite = unittest.TestLoader().loadTestsFromTestCase(AutomatonEditionTests)
    unittest.TextTestRunner(verbosity=2).run(suite)
