module SimpleTheoremProver where

import Data.List (find, nubBy)
import Data.Set (fromList, isSubsetOf, union)

-- Based on https://bor0.wordpress.com/2018/08/07/simple-theorem-prover-in-racket/.

-- | A named transformation that turns one theorem result into another.
data Rule a = Rule
    { name     :: String  -- ^ Human-readable label of the rule.
    , function :: a -> a  -- ^ Transformation applied to a theorem's result.
    }

-- | A rule is displayed by its name alone; its function has no textual form.
instance Show (Rule a) where
    show = name

-- | A theorem consists of the axiom it started from, the rules applied to it
-- (an ordered list), and the value those rules produced.
data Theorem a = Theorem
    { axiom    :: a         -- ^ Starting value of the derivation.
    , rulesThm :: [Rule a]  -- ^ Rules applied so far, oldest first.
    , result   :: a         -- ^ Current value of the derivation.
    }
    deriving (Show)

-- | A prover system consists of a collection of axioms and the rules that
-- may be applied to derive new theorems from them.
data TheoremProver a = TheoremProver
    { axioms         :: [Theorem a]  -- ^ Theorems every search starts from.
    , rulesThmProver :: [Rule a]     -- ^ Rules tried at every search step.
    }
    deriving (Show)

-- | Wraps a value as an axiom: a theorem proven by applying no rules, so its
-- result is the value itself.
mkAxiom :: a -> Theorem a
mkAxiom value = Theorem { axiom = value, rulesThm = [], result = value }

-- | Extends a theorem by one rule step: the rule is appended to the end of the
-- theorem's rule list and its function is applied to the current result.
-- The original axiom is carried over unchanged.
thmApplyRule :: Theorem a -> Rule a -> Theorem a
thmApplyRule theorem rule = Theorem
    { axiom    = axiom theorem
    , rulesThm = appliedRules
    , result   = newResult
    }
  where
    appliedRules = rulesThm theorem ++ [rule]
    newResult    = function rule (result theorem)

-- | Applies every rule of the prover to every theorem in the list.
--
-- The output is grouped by input theorem. It lists all rule applications to
-- the first theorem (in the prover's rule order), then all applications to
-- the second theorem, and so on. An empty theorem list yields an empty list.
thmApplyRules :: TheoremProver a -> [Theorem a] -> [Theorem a]
thmApplyRules prover thms =
    [thmApplyRule thm rule | thm <- thms, rule <- rules]
  where
    -- Looked up once rather than once per theorem.
    rules = rulesThmProver prover

-- | Merges two lists of proofs, dropping every proof whose 'result' already
-- appeared earlier in the combined list. Proofs from the first argument come
-- first, so they take priority over proofs from the second.
--
-- This keeps the search from revisiting results it has already reached,
-- e.g. when applying rules leads back to an earlier theorem or axiom.
--
-- The whole concatenation is deduplicated. Writing @nubBy f p1 ++ p2@ would
-- parse as @(nubBy f p1) ++ p2@ and let duplicates from @p2@ slip through.
-- Costs O(n^2) comparisons, since only 'Eq' is available on results.
mergeProofs :: Eq a => [Theorem a] -> [Theorem a] -> [Theorem a]
mergeProofs p1 p2 = nubBy sameResult (p1 ++ p2)
  where
    sameResult t1 t2 = result t1 == result t2

-- | Searches for a proof of @target@ by growing a proof tree breadth-first.
-- Each step applies every prover rule to every proof found so far.
--
-- Each iteration does one of three things:
--
-- * If some proof in @foundProofs@ has 'result' equal to @target@, it
--   returns the first such proof.
-- * Otherwise, if the new theorems yield no result beyond those already
--   in @foundProofs@, the search space is exhausted and it returns 'Nothing'.
-- * Otherwise it recurses with @depth - 1@ on the proofs merged by
--   'mergeProofs'.
--
-- @depth@ is the iteration budget. Once it reaches 0 the search returns
-- 'Nothing' without inspecting @foundProofs@. Any non-positive depth does
-- the same, so a negative depth cannot recurse forever when the rules keep
-- producing new results.
findProofIter :: Ord a => TheoremProver a -> a -> Int -> [Theorem a] -> Maybe (Theorem a)
findProofIter prover target depth foundProofs
    | depth <= 0 = Nothing
    | otherwise =
        case find ((target ==) . result) foundProofs of
            Just prf -> Just prf
            Nothing
                -- No new results were produced (new results are a subset of
                -- the known ones), so further iterations cannot help.
                | newResults `isSubsetOf` knownResults -> Nothing
                -- Otherwise keep producing new proofs.
                | otherwise ->
                    findProofIter prover target (depth - 1) (mergeProofs foundProofs theorems)
  where
    theorems     = thmApplyRules prover foundProofs
    knownResults = fromList (map result foundProofs)
    newResults   = fromList (map result theorems)

-- | Searches for a proof of @target@, starting from the prover's axioms, with
-- an iteration budget of @depth@ (see 'findProofIter').
findProof :: Ord a => TheoremProver a -> a -> Int -> Maybe (Theorem a)
findProof prover target depth = findProofIter prover target depth initialProofs
  where
    initialProofs = axioms prover