module HyLo.Formula(Formula(..), nnf,
                    composeFold, composeFoldM, composeMap, composeMapM, onShape,
                    mapSig, freeVars, boundVars,
                    --
                    TestFormula, metap_read_Formula, unit_tests)

where

import Test.QuickCheck ( Arbitrary(..), oneof, sized, variant )
import HyLo.Test       ( UnitTest, runTest )

import Text.Show.Functions ()

import Control.Monad          ( liftM2 )
import Control.Monad.Identity ( runIdentity )
import Control.Applicative    ( (<$>) )

import Text.Read ( Read(..) )

import Text.ParserCombinators.ReadP ( string, skipSpaces, )
import Text.ParserCombinators.ReadPrec ( ReadPrec, (<++), choice,
                                         lift, prec, pfail, reset )

import qualified Data.List as List

import Data.Generics.PlateDirect

import HyLo.Signature ( HasSignature(..),
                        emptySignature, merge,
                        addNomToSig, addPropToSig, addRelToSig )

import HyLo.Signature.Simple ( NomSymbol, PropSymbol, RelSymbol )

-- | Formulas of hybrid logic, parameterised over the type of nominals
-- (@n@), propositional symbols (@p@) and relation labels (@r@).
--
-- Keep the constructor order stable: the derived 'Ord' instance compares
-- values by the position of their constructor in this declaration.
data Formula n p r
    = Top                                       -- the constant "true"
    | Bot                                       -- the constant "false"
    | Prop p                                    -- a propositional symbol
    | Nom  n                                    -- a nominal
    --
    | Neg  (Formula n p r)                      -- negation
    --
    | (Formula n p r)  :&:   (Formula n p r)    -- conjunction
    | (Formula n p r)  :|:   (Formula n p r)    -- disjunction
    | (Formula n p r) :-->:  (Formula n p r)    -- implication
    | (Formula n p r) :<-->: (Formula n p r)    -- bi-implication
    --
    | Diam r (Formula n p r)                    -- diamond over relation r
    | Box  r (Formula n p r)                    -- box over relation r
    --
    | At  n  (Formula n p r)                    -- satisfaction, printed "n:f"
    --
    | A (Formula n p r)                         -- A modality (nnf: dual of E)
    | E (Formula n p r)                         -- E modality (nnf: dual of A)
    --
    | D (Formula n p r)                         -- D modality (nnf: dual of B)
    | B (Formula n p r)                         -- B modality (nnf: dual of D)
    --
    | Down n (Formula n p r)                    -- binder, printed "down n . f"
    deriving (Eq, Ord)

-- Fixities for writing formulas in Haskell source. ':&:' and ':|:' share
-- one level and associate to the left, so  a :&: b :|: c  parses as
-- (a :&: b) :|: c.
infixl 8 :&:, :|:
infix  7 :<-->:
infixr 7 :-->:

-- Direct Uniplate instance: every immediate sub-formula is registered as a
-- child with '|*', while symbols (propositions, nominals, relation labels)
-- are carried along as non-target fields with '|-'.
instance Uniplate (Formula n p r) where
    uniplate form = case form of
        Top            -> plate Top
        Bot            -> plate Bot
        --
        Prop s         -> plate Prop |- s
        Nom  i         -> plate Nom  |- i
        --
        Neg g          -> plate Neg |* g
        --
        lhs :&:    rhs -> plate (:&:)    |* lhs |* rhs
        lhs :|:    rhs -> plate (:|:)    |* lhs |* rhs
        lhs :-->:  rhs -> plate (:-->:)  |* lhs |* rhs
        lhs :<-->: rhs -> plate (:<-->:) |* lhs |* rhs
        --
        Diam rel g     -> plate Diam |- rel |* g
        Box  rel g     -> plate Box  |- rel |* g
        --
        At   i g       -> plate At   |- i |* g
        --
        A g            -> plate A |* g
        E g            -> plate E |* g
        --
        D g            -> plate D |* g
        B g            -> plate B |* g
        --
        Down x g       -> plate Down |- x |* g

-- The signature of a formula collects every nominal (including those under
-- 'At' and those bound by 'Down'), propositional symbol and relation label
-- it mentions; binary connectives merge the signatures of both operands.
instance (Ord n, Ord p, Ord r) => HasSignature (Formula n p r) where
    type NomsOf  (Formula n p r) = n
    type PropsOf (Formula n p r) = p
    type RelsOf  (Formula n p r) = r
    --
    getSignature form = case form of
        Top            -> emptySignature
        Bot            -> emptySignature
        --
        Prop s         -> addPropToSig s emptySignature
        Nom  i         -> addNomToSig  i emptySignature
        --
        Neg g          -> getSignature g
        --
        lhs :&:    rhs -> merge (getSignature lhs) (getSignature rhs)
        lhs :|:    rhs -> merge (getSignature lhs) (getSignature rhs)
        lhs :-->:  rhs -> merge (getSignature lhs) (getSignature rhs)
        lhs :<-->: rhs -> merge (getSignature lhs) (getSignature rhs)
        --
        Diam rel g     -> addRelToSig rel (getSignature g)
        Box  rel g     -> addRelToSig rel (getSignature g)
        --
        At   i g       -> addNomToSig i (getSignature g)
        --
        A g            -> getSignature g
        E g            -> getSignature g
        --
        D g            -> getSignature g
        B g            -> getSignature g
        --
        Down x g       -> addNomToSig x (getSignature g)

-- | Negation normal form: push every negation inwards until it only
-- applies to 'Top', 'Bot', a propositional symbol or a nominal.
--
-- On the way, implications and bi-implications are eliminated. Each
-- modality is swapped for its dual when a negation passes through it
-- ('Diam'/'Box', 'A'/'E', 'D'/'B'). The satisfaction operator 'At' and the
-- binder 'Down' keep their shape under negation.
--
-- Nested satisfaction operators collapse: in @i:j:f@ the outer nominal @i@
-- is discarded, both with and without a surrounding negation.
nnf :: Formula n p r -> Formula n p r
-- Formulas that are already in NNF are returned unchanged; note that
-- 'Neg Top' and 'Neg Bot' are kept as they are (not rewritten to Bot/Top).
nnf f@Top                = f
nnf f@Bot                = f
nnf f@(Neg Top)          = f
nnf f@(Neg Bot)          = f
nnf f@Prop{}             = f
nnf f@Nom{}              = f
nnf f@(Neg Prop{})       = f
nnf f@(Neg Nom{})        = f
--
-- Double negation, De Morgan, and elimination of (bi-)implications.
nnf   (Neg (Neg f))      = nnf f
nnf      (f1 :&: f2)     = (nnf       f1)           :&: (nnf        f2)
nnf (Neg (f1 :&: f2))    = (nnf $ Neg f1)           :|: (nnf $ Neg  f2)
nnf      (f1 :|: f2)     = (nnf       f1)           :|: (nnf        f2)
nnf (Neg (f1 :|: f2))    = (nnf $ Neg f1)           :&: (nnf $ Neg  f2)
nnf      (f1 :-->: f2)   = (nnf $ Neg f1)           :|: (nnf        f2)
nnf (Neg (f1 :-->: f2))  = (nnf       f1)           :&: (nnf $ Neg  f2)
nnf      (f1 :<-->: f2)  = (nnf $      f1 :-->: f2) :&: (nnf $      f2 :-->: f1)
nnf (Neg (f1 :<-->: f2)) = (nnf $ Neg (f1 :-->: f2)):|: (nnf $ Neg (f2 :-->: f1))
-- Relational modalities: a negation turns a diamond into a box and back.
nnf      (Diam r f)      = Diam r $ nnf      f
nnf (Neg (Diam r f))     = Box  r $ nnf (Neg f)
nnf      (Box r f)       = Box  r $ nnf      f
nnf (Neg (Box r f))      = Diam r $ nnf (Neg f)
--
-- At directly over At: drop the outer nominal. These equations must stay
-- before the general 'At' equations below, which would otherwise match.
nnf      (At _ f@At{})    = nnf      f
nnf (Neg (At _ f@At{}))   = nnf (Neg f)
--
-- 'At' and 'Down' keep their shape under negation; A/E and D/B are
-- exchanged for their duals.
nnf      (At n f)        = At n $ nnf f
nnf (Neg (At n f))       = At n $ nnf (Neg f)
nnf      (A f)           = A (nnf f)
nnf (Neg (A f))          = E $ nnf (Neg f)
nnf      (E f)           = E (nnf f)
nnf (Neg (E f))          = A $ nnf (Neg f)
nnf      (D f)           = D (nnf f)
nnf (Neg (D f))          = B $ nnf (Neg f)
nnf      (B f)           = B (nnf f)
nnf (Neg (B f))          = D $ nnf (Neg f)
nnf      (Down x f)      = Down x $ nnf f
nnf (Neg (Down x f))     = Down x $ nnf (Neg f)

-- Printer producing the concrete syntax read back by the 'Read' instance.
-- The precedence context only matters for binary connectives. Each has its
-- own level (2 for " ^ ", 3 for " v ", 4 for " --> ", 6 for " <--> ") and
-- is printed without parentheses only when the context is 0 or equal to
-- that level. Operands of prefix operators are printed in context 1, so a
-- binary operand there is always parenthesised.
instance (Show n, Show p, Show r) => Show (Formula n p r) where
    showsPrec ctx form = case form of
        Top            -> showString "true"
        Bot            -> showString "false"
        Prop s         -> shows s
        Nom  i         -> shows i
        --
        Neg g          -> prefixed "-" g
        --
        lhs :&:    rhs -> connective 2 " ^ "    2 2 lhs rhs
        lhs :|:    rhs -> connective 3 " v "    3 3 lhs rhs
        lhs :-->:  rhs -> connective 4 " --> "  5 4 lhs rhs
        lhs :<-->: rhs -> connective 6 " <--> " 6 6 lhs rhs
        --
        Diam rel g     -> showChar '<' . shows rel . showChar '>' . showsPrec 1 g
        Box  rel g     -> showChar '[' . shows rel . showChar ']' . showsPrec 1 g
        --
        At   i g       -> shows i . showChar ':' . showsPrec 1 g
        --
        A g            -> prefixed "A " g
        E g            -> prefixed "E " g
        D g            -> prefixed "D " g
        B g            -> prefixed "B " g
        --
        Down x g       -> showString "down " . shows x . showString " . "
                              . showsPrec 1 g
      where
        -- a prefix operator followed by its operand in context 1
        prefixed op g = showString op . showsPrec 1 g
        -- "l op r", parenthesised unless the context is 0 or the
        -- connective's own level; the operands are shown at lp and rp
        connective level op lp rp l r =
            showParen (ctx > 0 && ctx /= level) $
                showsPrec lp l . showString op . showsPrec rp r


-- Reader for the syntax produced by 'show'. A formula starts with an
-- operand, either one parsed by 'left' or a parenthesised formula. It may
-- then be followed by a chain of binary connectives, handled by 'infixOps'.
instance (Read n, Read p, Read r) => Read (Formula n p r) where
    readPrec = choice (map (>>= infixOps) [left, paren readPrec])


-- | Parse a formula that does not start with a parenthesis. This is one of:
-- @true@; @false@; a propositional symbol or a nominal (read with their
-- own 'Read' instances and followed by optional whitespace); or a prefix
-- operator applied to an operand. The prefix operators are negation "-",
-- satisfaction "n:", the "\<r>" and "[r]" modalities, "A", "E", "D", "B"
-- and "down x .". All alternatives are combined with the symmetric 'choice'.
--
-- The prefix alternatives run under @prec 9@, so their operand is read at
-- precedence 9. At that level 'infixOps' accepts no binary connective, so
-- a binary operand of a prefix operator must be parenthesised. This
-- matches what 'show' prints (e.g. @-(p ^ q)@).
--
-- NOTE(review): because 'choice' is symmetric, the Read syntaxes of @p@
-- and @n@ presumably must not overlap with each other or with the keywords
-- above. Otherwise 'read' would see several parses — confirm.
left :: (Read n, Read p, Read r) => ReadPrec (Formula n p r)
left = choice [
           do token "true" ; return Top,
           do token "false"; return Bot,
           Prop <$> prop,
           Nom  <$> nom,
           prec 9 $ do token "-"; Neg <$> readPrec,
           prec 9 $ do i <- nom; token ":"; At  i <$> readPrec,
           prec 9 $ do token "<"; r <- readPrec; token ">"; Diam r <$> readPrec,
           prec 9 $ do token "["; r <- readPrec; token "]"; Box  r <$> readPrec,
           prec 9 $ do token "A"; A <$> readPrec,
           prec 9 $ do token "E"; E <$> readPrec,
           prec 9 $ do token "D"; D <$> readPrec,
           prec 9 $ do token "B"; B <$> readPrec,
           prec 9 $ do token "down"; x <- nom; token "."; Down x <$> readPrec
          ]
    -- symbols are read with their own Read instances, then trailing blanks
    where prop = do p <- readPrec; blanks; return p
          nom  = do i <- readPrec; blanks; return i

-- | Having parsed an operand @f@, try to extend it with a binary
-- connective: "^", "v", "-->" or "<-->", tried in that order via the
-- left-biased 'first'. On success, the right operand is read with
-- 'readPrec' and the combined formula is passed to 'infixOps' again.
-- @return f@ is the fallback when no connective can be parsed.
--
-- NOTE(review): the right operand is parsed at the precedence the
-- connective branch ran at, which in practice appears to always be 0 (see
-- 'onPrec'). That operand parse therefore consumes the rest of the chain,
-- so chains associate to the right whatever the connective: "p ^ q v r"
-- seems to read as p ^ (q v r). 'show' parenthesises mixed connectives,
-- so the show/read round trip still holds. Hand-written input, however,
-- may not follow the usual precedences; confirm before relying on it.
infixOps :: (Read n,Read p,Read r) => Formula n p r -> ReadPrec (Formula n p r)
infixOps f = first [onPrec 4 $
                      do token "^"   ; f' <- readPrec; infixOps (f  :&: f'),
                    onPrec 3 $
                      do token "v"   ; f' <- readPrec; infixOps (f  :|: f'),
                    onPrec 2 $
                      do token "-->" ; f' <- readPrec; infixOps (f :-->:f'),
                    onPrec 1 $
                      do token "<-->"; f' <- readPrec; infixOps (f:<-->:f'),
                    return f
                   ]

-- | Guard the connective parser @a@ of level @p@ by the current precedence
-- context @c@. 'prec' @n@ only succeeds when @c <= n@, and then runs its
-- argument at precedence @n@. The branches are combined with the
-- left-biased 'first', and a 'Nothing' result is turned into 'pfail':
--
--   * @c == 0@: run @a@ at precedence 0 (if it fails, the next branch
--     yields 'Nothing', so the whole parser fails);
--   * @1 <= c <= p - 1@: fail;
--   * @c == p@: run @a@ at precedence @p@;
--   * @c > p@: fail.
--
-- NOTE(review): 'readPrec' starts at 0, 'left' raises to 9 and 'paren'
-- resets to 0. Given that, the @c == p@ branch appears unreachable in
-- practice; confirm before relying on per-level behaviour.
onPrec :: Int -> ReadPrec t -> ReadPrec t
onPrec p a  = do r <- first [prec 0 $ Just <$> a,
                             prec (p - 1) $ return Nothing,
                             prec p $ Just <$> a]
                 maybe pfail return r

-- | Left-biased alternation over a list of parsers. The first parser that
-- produces any result wins, and the remaining ones are not used.
--
-- The fold is seeded with 'pfail' rather than using 'foldr1'. That makes
-- the function total: an empty list simply fails to parse instead of
-- raising an exception. For non-empty lists the result is unchanged,
-- because @q <++ pfail@ behaves exactly like @q@.
first :: [ReadPrec t] -> ReadPrec t
first = foldr (<++) pfail

-- | Match the literal string @s@, allowing whitespace on both sides.
token :: String -> ReadPrec ()
token s = do
    blanks
    _ <- lift (string s)
    blanks

-- | Skip any amount of whitespace, possibly none.
blanks :: ReadPrec ()
blanks = lift (skipSpaces >> return ())

-- | Parse @a@ between parentheses. The inner parser runs with the
-- precedence context reset to zero ('reset'), so any formula may appear
-- inside the parentheses.
paren :: ReadPrec a -> ReadPrec a
paren a = token "(" >> reset a >>= closeWith
    where closeWith r = token ")" >> return r


-- composeXX functions follow the idea from
-- "A pattern for almost compositional functions", Bringert and Ranta.
-- | Fold over the immediate sub-formulas of a formula. @g@ is applied to
-- every direct child, and the two results of a binary connective are
-- merged with @combine@. Formulas without sub-formulas (constants,
-- propositional symbols, nominals) yield @zero@.
composeFold :: b
            -> (b -> b -> b)
            -> (Formula n p r -> b)
            -> (Formula n p r -> b)
composeFold zero combine g formula =
    case formula of
      Neg sub         -> g sub
      lhs :&:    rhs  -> both lhs rhs
      lhs :|:    rhs  -> both lhs rhs
      lhs :-->:  rhs  -> both lhs rhs
      lhs :<-->: rhs  -> both lhs rhs
      Diam _ sub      -> g sub
      Box  _ sub      -> g sub
      At   _ sub      -> g sub
      A sub           -> g sub
      E sub           -> g sub
      D sub           -> g sub
      B sub           -> g sub
      Down _ sub      -> g sub
      _               -> zero
  where
    both lhs rhs = g lhs `combine` g rhs

-- | Monadic variant of 'composeFold'. For binary connectives, @g@ runs on
-- the left operand first, then on the right one, and the two results are
-- merged with @combine@. Formulas without sub-formulas yield @zero@.
composeFoldM :: Monad m
             => m b
             -> (b -> b -> m b)
             -> (Formula n p r -> m b)
             -> (Formula n p r -> m b)
composeFoldM zero combine g formula =
    case formula of
      Neg sub         -> g sub
      lhs :&:    rhs  -> both lhs rhs
      lhs :|:    rhs  -> both lhs rhs
      lhs :-->:  rhs  -> both lhs rhs
      lhs :<-->: rhs  -> both lhs rhs
      Diam _ sub      -> g sub
      Box  _ sub      -> g sub
      At   _ sub      -> g sub
      A sub           -> g sub
      E sub           -> g sub
      D sub           -> g sub
      B sub           -> g sub
      Down _ sub      -> g sub
      _               -> zero
  where
    both lhs rhs = g lhs >>= \x -> g rhs >>= \y -> combine x y


-- | Rebuild one layer of a formula, applying @g@ to every immediate
-- sub-formula and keeping connectives and symbols in place. Formulas
-- without sub-formulas are handed to @baseCase@ instead.
composeMap :: (Formula n p r -> Formula n p r)
           -> (Formula n p r -> Formula n p r)
           -> (Formula n p r -> Formula n p r)
composeMap baseCase g formula =
    case formula of
      Neg sub         -> Neg (g sub)
      lhs :&:    rhs  -> g lhs :&: g rhs
      lhs :|:    rhs  -> g lhs :|: g rhs
      lhs :-->:  rhs  -> g lhs :-->: g rhs
      lhs :<-->: rhs  -> g lhs :<-->: g rhs
      Diam rel sub    -> Diam rel (g sub)
      Box  rel sub    -> Box  rel (g sub)
      At   i sub      -> At   i   (g sub)
      A sub           -> A (g sub)
      E sub           -> E (g sub)
      D sub           -> D (g sub)
      B sub           -> B (g sub)
      Down x sub      -> Down x (g sub)
      leaf            -> baseCase leaf

-- | Monadic variant of 'composeMap'. For binary connectives, @g@ runs on
-- the left operand before the right one. Formulas without sub-formulas are
-- handed to @baseCase@.
composeMapM :: (Monad m, Functor m)
            => (Formula n p r -> m (Formula n p r))
            -> (Formula n p r -> m (Formula n p r))
            -> (Formula n p r -> m (Formula n p r))
composeMapM baseCase g formula =
    case formula of
      Neg sub         -> fmap Neg (g sub)
      lhs :&:    rhs  -> both (:&:)    lhs rhs
      lhs :|:    rhs  -> both (:|:)    lhs rhs
      lhs :-->:  rhs  -> both (:-->:)  lhs rhs
      lhs :<-->: rhs  -> both (:<-->:) lhs rhs
      Diam rel sub    -> fmap (Diam rel) (g sub)
      Box  rel sub    -> fmap (Box  rel) (g sub)
      At   i sub      -> fmap (At   i)   (g sub)
      A sub           -> fmap A (g sub)
      E sub           -> fmap E (g sub)
      D sub           -> fmap D (g sub)
      B sub           -> fmap B (g sub)
      Down x sub      -> fmap (Down x) (g sub)
      leaf            -> baseCase leaf
  where
    both op lhs rhs = liftM2 op (g lhs) (g rhs)


-- | Rebuild one layer of a formula over new symbol types. The nominals,
-- propositional symbols and relation labels at this layer are renamed with
-- @mn@, @mp@ and @mr@. Every immediate sub-formula is handed to @g@,
-- typically the recursive call.
onShape :: (n -> n')
        -> (p -> p')
        -> (r -> r')
        -> (Formula n p r -> Formula n' p' r')
        -> (Formula n p r -> Formula n' p' r')
onShape mn mp mr g formula =
    case formula of
      Top             -> Top
      Bot             -> Bot
      Prop s          -> Prop (mp s)
      Nom  i          -> Nom  (mn i)
      Neg sub         -> Neg (g sub)
      lhs :&:    rhs  -> g lhs :&: g rhs
      lhs :|:    rhs  -> g lhs :|: g rhs
      lhs :-->:  rhs  -> g lhs :-->: g rhs
      lhs :<-->: rhs  -> g lhs :<-->: g rhs
      Diam rel sub    -> Diam (mr rel) (g sub)
      Box  rel sub    -> Box  (mr rel) (g sub)
      At   i sub      -> At   (mn i)   (g sub)
      A sub           -> A (g sub)
      E sub           -> E (g sub)
      D sub           -> D (g sub)
      B sub           -> B (g sub)
      Down x sub      -> Down (mn x) (g sub)


-- | Rename every nominal, propositional symbol and relation label in a
-- formula, at all depths, keeping its shape.
mapSig :: (n -> n')
       -> (p -> p')
       -> (r -> r')
       -> Formula n  p  r
       -> Formula n' p' r'
mapSig mn mp mr = go
    where go = onShape mn mp mr go

-- | Nominals occurring free in a formula, without duplicates. A nominal
-- occurs free when it appears as a formula ('Nom') or as the nominal of an
-- 'At', outside the scope of a 'Down' binder for that nominal.
freeVars :: Eq n => Formula n p r -> [n]
freeVars formula =
    case formula of
      Nom v      -> [v]
      At v sub   -> [v] `List.union` freeVars sub
      Down v sub -> List.delete v (freeVars sub)
      other      -> composeFold [] List.union freeVars other

-- | Variables bound by some 'Down' binder anywhere in the formula, each
-- listed once, in the order in which 'universe' first visits them.
--
-- Duplicates are removed so that the result is set-like, like the one of
-- 'freeVars'. This is also what the @Eq n@ constraint is needed for.
boundVars :: Eq n => Formula n p r -> [n]
boundVars f = List.nub [i | Down i _ <- universe f]

---------------------------------------
-- QuickCheck stuff                   -
---------------------------------------

-- Random formulas for QuickCheck, bounded by the size parameter. At size 0
-- only constants, propositional symbols and nominals are generated. A
-- larger size may also yield a connective: unary connectives get
-- size - 1, and binary connectives split the size between their operands.
instance (Arbitrary n, Arbitrary p, Arbitrary r) => Arbitrary (Formula n p r)
  where
    arbitrary = sized gen
        where
          gen size
              | size == 0 = oneof leaves
              | otherwise = oneof (leaves ++ compound size)
          leaves = [return Top,
                    return Bot,
                    fmap Prop arbitrary,
                    fmap Nom  arbitrary]
          compound size =
              let sub       = gen (size - 1)
                  half      = size `div` 2
                  binary op = liftM2 op (gen half) (gen (size - half))
              in  [fmap Neg sub,
                   --
                   binary (:&:),
                   binary (:|:),
                   binary (:-->:),
                   binary (:<-->:),
                   --
                   liftM2 Diam arbitrary sub,
                   liftM2 Box  arbitrary sub,
                   liftM2 At   arbitrary sub,
                   liftM2 Down arbitrary sub,
                   --
                   fmap A sub,
                   fmap E sub,
                   --
                   fmap D sub,
                   fmap B sub]

    -- each constructor perturbs the generator with its own variant number,
    -- then with its fields from left to right
    coarbitrary form = case form of
        Top            -> variant 0
        Bot            -> variant 1
        Prop s         -> variant 2  . coarbitrary s
        Nom  i         -> variant 3  . coarbitrary i
        Neg  g         -> variant 4  . coarbitrary g
        lhs :&:    rhs -> variant 5  . coarbitrary lhs . coarbitrary rhs
        lhs :|:    rhs -> variant 6  . coarbitrary lhs . coarbitrary rhs
        lhs :-->:  rhs -> variant 7  . coarbitrary lhs . coarbitrary rhs
        lhs :<-->: rhs -> variant 8  . coarbitrary lhs . coarbitrary rhs
        Diam rel g     -> variant 9  . coarbitrary rel . coarbitrary g
        Box  rel g     -> variant 10 . coarbitrary rel . coarbitrary g
        At   i g       -> variant 11 . coarbitrary i . coarbitrary g
        A g            -> variant 12 . coarbitrary g
        E g            -> variant 13 . coarbitrary g
        D g            -> variant 14 . coarbitrary g
        B g            -> variant 15 . coarbitrary g
        Down x g       -> variant 16 . coarbitrary x . coarbitrary g

-- | Round-trip property for the 'Show' and 'Read' instances. Showing a
-- formula, reading the text back and showing the result again must
-- reproduce the same text.
--
-- NOTE(review): the comparison is on printed text, not on formulas. This
-- is presumably because reading may re-associate chains of connectives, so
-- structural equality need not hold. The @Eq@ constraints are currently
-- unused.
metap_read_Formula :: (Show n, Read n, Eq n,
                       Show p, Read p, Eq p,
                       Show r, Read r, Eq r)
                   => Formula n p r -> Bool
metap_read_Formula f = rendered == show (reparsed `asTypeOf` f)
    where rendered = show f
          reparsed = read rendered

-- | The concrete formula type used by the properties in this module.
type TestFormula = Formula NomSymbol PropSymbol RelSymbol

-- | 'metap_read_Formula' specialised to 'TestFormula'.
prop_read :: TestFormula  -> Bool
prop_read f = metap_read_Formula f

-- | Tying 'composeMap' with an identity base case back to itself rebuilds
-- every formula unchanged.
prop_composeMapId :: TestFormula -> Bool
prop_composeMapId f =
    let rebuild = composeMap id rebuild
    in  rebuild f == f

-- | 'composeMapM' in the 'Identity' monad agrees with 'composeMap'.
prop_composeMapMIdent :: (TestFormula -> TestFormula)
                      -> (TestFormula -> TestFormula)
                      -> TestFormula
                      -> Bool
prop_composeMapMIdent bc g f = direct == viaIdentity
    where direct      = composeMap bc g f
          viaIdentity = runIdentity (composeMapM (return . bc) (return . g) f)

-- | Renaming symbols with 'onShape' never changes the shape of a formula.
-- The shape is obtained by erasing every symbol to @()@.
prop_onShape :: TestFormula
             -> (NomSymbol  -> NomSymbol)
             -> (PropSymbol -> PropSymbol)
             -> (RelSymbol  -> RelSymbol)
             -> Bool
prop_onShape f mn mp mr = erase f == erase (relabel f)
    where relabel = onShape mn mp mr relabel
          erase   = mapSig (const ()) (const ()) (const ())

-- | 'composeFoldM' in the 'Identity' monad agrees with 'composeFold'.
prop_composeFoldMIdent :: Int
                       -> (Int -> Int -> Int)
                       -> (TestFormula -> Int)
                       -> TestFormula
                       -> Bool
prop_composeFoldMIdent z c g f = direct == viaIdentity
    where direct      = composeFold z c g f
          viaIdentity = runIdentity (composeFoldM (return z) lifted (return . g) f)
          lifted a b  = return (c a b)


-- | Mapping the identity over every symbol leaves a formula unchanged.
prop_mapSigId :: TestFormula -> Bool
prop_mapSigId f = unchanged == f
    where unchanged = mapSig id id id f

-- | Splitting a formula into its children and rebuilding it from them
-- gives back the same formula.
prop_uniplateId :: TestFormula -> Bool
prop_uniplateId f = f == rebuild children
    where (children, rebuild) = uniplateList f

-- | Replacing the first child of a formula by @g'@ yields the original
-- formula exactly when @g'@ equals that child. A formula without children
-- must be rebuilt unchanged from the empty list.
prop_uniplateRepl :: TestFormula -> TestFormula -> Bool
prop_uniplateRepl f g' = check (uniplateList f)
    where check ([],     rebuild) = rebuild [] == f
          check (g : gs, rebuild) = (rebuild (g' : gs) == f) == (g == g')

-- | The QuickCheck properties of this module, each paired with a label.
unit_tests :: UnitTest
unit_tests =
    [ ("read/show",             runTest prop_read)
    , ("composeMap id",         runTest prop_composeMapId)
    , ("composeMapM Identity",  runTest prop_composeMapMIdent)
    , ("composeFoldM Identity", runTest prop_composeFoldMIdent)
    , ("onShape shape",         runTest prop_onShape)
    , ("mapSigId id",           runTest prop_mapSigId)
    , ("uniplate id",           runTest prop_uniplateId)
    , ("uniplate replacement",  runTest prop_uniplateRepl)
    ]