-----------------------------------------------------------------------------
-- Copyright 2016, Ideas project team. This file is distributed under the
-- terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Ideas.Common.CyclicTree
   ( -- * Data type
     CyclicTree
     -- * Constructor functions
   , node, node0, node1, node2, leaf, label
     -- * Querying
   , isNode, isLeaf, isLabel
     -- * Replace functions
   , replaceNode, replaceLeaf, replaceLabel, shrinkTree
     -- * Fold and algebra
   , fold, foldUnwind
   , CyclicTreeAlg, fNode, fLeaf, fLabel, fRec, fVar
   , emptyAlg, monoidAlg
   ) where

import Control.Applicative
import Control.Monad
import Data.List (intercalate)
import Ideas.Common.Classes
import Ideas.Common.Id
import Test.QuickCheck hiding (label)

--------------------------------------------------------------
-- Data type

-- | A tree whose internal nodes carry values of type @a@ and whose leaves
-- carry values of type @b@. Cycles are represented explicitly: a 'Rec' node
-- binds a number, and a 'Var' with that number inside its body refers back
-- to it. 'Label' annotates a subtree with an 'Id'.
data CyclicTree a b
   = Node a [CyclicTree a b]   -- ^ internal node: a value and its children
   | Leaf b                    -- ^ leaf holding a value
   | Label Id (CyclicTree a b) -- ^ subtree annotated with an identifier
   | Rec Int (CyclicTree a b)  -- ^ binder: 'Var's with this number in the body point back here
   | Var Int                   -- ^ back-reference to the nearest enclosing 'Rec' with this number

-- | Renders the tree structure without unwinding cycles. A node is shown as
-- its value followed by a parenthesised, comma-separated child list (omitted
-- when there are no children). A label is shown as @label:subtree@, a binder
-- as @#n=body@ and a back-reference as @#n@.
instance (Show a, Show b) => Show (CyclicTree a b) where
   show = fold showAlg
    where
      showAlg = Alg
         { fNode  = \a children -> show a ++ par children
         , fLeaf  = \b -> show b
         , fLabel = \l body -> show l ++ ":" ++ body
         , fRec   = \n body -> "#" ++ show n ++ "=" ++ body
         , fVar   = \n -> "#" ++ show n
         }

-- | Maps node values with the first function and leaf values with the
-- second; labels, binders and back-references are rebuilt unchanged.
instance BiFunctor CyclicTree where
   biMap f g = fold (Alg (\a ts -> Node (f a) ts) (\b -> Leaf (g b)) Label Rec Var)

-- | Maps over leaf values only; node values are left untouched.
instance Functor (CyclicTree d) where
   fmap g = biMap id g

-- | 'pure' builds a single leaf. @fs <*> xs@ replaces every leaf @f@ of
-- @fs@ with @fmap f xs@, keeping the rest of @fs@'s structure.
instance Applicative (CyclicTree d) where
   pure      = Leaf
   fs <*> xs = replaceLeaf (\f -> fmap f xs) fs

-- | Bind substitutes, for every leaf, the tree produced by the continuation
-- (see 'replaceLeaf'). 'return' uses the canonical definition @return =
-- pure@. That is behaviourally identical to @leaf@, because 'pure' is 'leaf',
-- and it avoids GHC's @-Wnoncanonical-monad-instances@ warning.
instance Monad (CyclicTree d) where
   return  = pure
   t >>= f = replaceLeaf f t

-- | Combines the leaf values from left to right. Back-references ('Var')
-- contribute 'mempty', so cycles are not unwound and each leaf that occurs
-- once in the structure is visited once.
instance Foldable (CyclicTree d) where
   foldMap f = fold (Alg (\_ ms -> mconcat ms) f (\_ m -> m) (\_ m -> m) (\_ -> mempty))

-- | Runs the effect on every leaf value (left to right) and rebuilds the
-- surrounding structure, including labels, binders and back-references.
instance Traversable (CyclicTree d) where
   traverse f = fold Alg
      { fNode  = \a ts -> liftA (node a) (sequenceA ts)
      , fLeaf  = \b    -> liftA leaf (f b)
      , fLabel = \l t  -> liftA (label l) t
      , fRec   = \n t  -> liftA (Rec n) t
      , fVar   = \n    -> pure (Var n)
      }

-- | Ties a recursive knot by wrapping the body in a 'Rec' binder whose
-- variable is passed to @f@.
--
-- The binder number must be fresh with respect to the body. It has to exceed
-- every 'Var' occurring in it and also every 'Rec' binder in it. Otherwise
-- the new 'Var' can be captured by an inner 'Rec' that happens to use the
-- same number. For example, a body built from @fix (const t)@ binds a number
-- it never references. Capture silently changes what the tree unwinds to.
-- To find the numbers in use, the body is first built once with the dummy
-- variable @-1@.
instance Fix (CyclicTree a b) where
   fix f = Rec n (f (Var n))
    where
      probe   = f (Var (-1))
      -- numbers bound by 'Rec' nodes inside the probed body
      binders = fold monoidAlg {fRec = (:)} probe
      n       = maximum (-1 : vars probe ++ binders) + 1

-- | Random trees are generated by 'arbTree' (scaled by QuickCheck's size
-- parameter) and shrunk by 'shrinkTree'.
instance (Arbitrary a, Arbitrary b) => Arbitrary (CyclicTree a b) where
   arbitrary = sized (\size -> arbTree size)
   shrink t  = shrinkTree t

-- | Size-bounded generator for random cyclic trees. The worker tracks how
-- many 'Rec' binders enclose the current position. Binders are numbered
-- 1, 2, ... from the outside in, and a generated 'Var' only uses numbers in
-- @[1 .. depth]@, so every back-reference has an enclosing binder. At size 0
-- only leaves and (when in scope) variables are produced. Otherwise the size
-- is halved for subtrees.
arbTree :: (Arbitrary a, Arbitrary b) => Int -> Gen (CyclicTree a b)
arbTree = go 0
 where
   go depth size
      | size == 0 = frequency (leafGen : varGens)
      | otherwise = frequency
           [ (3, liftM2 node arbitrary children)
           , (2, go depth 0)
           , (1, liftM2 label labelId subtree)
           , (1, liftM (Rec next) (go next half))
           ]
    where
      half     = size `div` 2
      next     = depth + 1
      subtree  = go depth half
      children = choose (0, 3) >>= \k -> replicateM k subtree
      labelId  = elements [ newId [ch] | ch <- ['A' .. 'Z'] ]
      leafGen  = (3, liftM leaf arbitrary)
      varGens  = [ (1, elements [ Var i | i <- [1 .. depth] ]) | depth > 0 ]

-- | Candidate shrinks of a tree. A node may be replaced by any of its
-- children, or keep its value while exactly one child is shrunk. A label may
-- be dropped, or its subtree shrunk. A 'Rec' binder is kept and only its body
-- is shrunk, because removing it could leave 'Var' occurrences unbound.
-- Leaves and variables have no shrinks.
shrinkTree :: CyclicTree a b -> [CyclicTree a b]
shrinkTree (Node a ts) = ts ++ [ node a ts' | ts' <- shrinkTrees ts ]
shrinkTree (Label l t) = t : [ Label l t' | t' <- shrinkTree t ]
shrinkTree (Rec n t)   = [ Rec n t' | t' <- shrinkTree t ]
shrinkTree _           = []

-- | All lists obtained by shrinking exactly one element (via 'shrinkTree')
-- and leaving the others unchanged. Shrinks of earlier elements come first.
shrinkTrees :: [CyclicTree a b] -> [[CyclicTree a b]]
shrinkTrees ts =
   case ts of
      []     -> []
      x : xs -> [ x' : xs | x' <- shrinkTree x ] ++ [ x : xs' | xs' <- shrinkTrees xs ]

-- local helpers
-- | Parenthesised, comma-separated rendering of a list of strings; the
-- empty list renders as the empty string (no parentheses).
par :: [String] -> String
par []    = ""
par items = '(' : intercalate ", " items ++ ")"

-- | Numbers of all 'Var' occurrences in the tree, left to right and with
-- duplicates. Numbers bound by 'Rec' are not included.
vars :: CyclicTree a b -> [Int]
vars = fold monoidAlg {fVar = \n -> [n]}

--------------------------------------------------------------
-- Constructor functions

-- | An internal node with the given value and children.
node :: a -> [CyclicTree a b] -> CyclicTree a b
node a children = Node a children

-- | An internal node without children.
node0 :: a -> CyclicTree a b
node0 = flip node []

-- | An internal node with exactly one child.
node1 :: a -> CyclicTree a b -> CyclicTree a b
node1 a = node a . (: [])

-- | An internal node with exactly two children, in the given order.
node2 :: a -> CyclicTree a b -> CyclicTree a b -> CyclicTree a b
node2 a lhs rhs = node a (lhs : [rhs])

-- | A leaf holding the given value.
leaf :: b -> CyclicTree a b
leaf b = Leaf b

-- | Annotates a subtree with an identifier, converted with 'newId'.
label :: IsId n => n -> CyclicTree a b -> CyclicTree a b
label n t = Label (newId n) t

--------------------------------------------------------------
-- Querying

-- | The value and children of a 'Node'. Any other root constructor,
-- including a 'Label' wrapped around a node, gives 'Nothing'.
isNode :: CyclicTree a b -> Maybe (a, [CyclicTree a b])
isNode tree =
   case tree of
      Node a children -> Just (a, children)
      _               -> Nothing

-- | The value of a 'Leaf'; 'Nothing' for any other root constructor.
isLeaf :: CyclicTree a b -> Maybe b
isLeaf tree =
   case tree of
      Leaf b -> Just b
      _      -> Nothing

-- | The identifier and subtree of a 'Label'; 'Nothing' for any other root
-- constructor.
isLabel :: CyclicTree a b -> Maybe (Id, CyclicTree a b)
isLabel tree =
   case tree of
      Label l body -> Just (l, body)
      _            -> Nothing

--------------------------------------------------------------
-- Replace functions

-- | Rewrites every 'Node' bottom-up. The function receives the node value
-- and its already-rewritten children; all other constructors are rebuilt
-- unchanged.
replaceNode :: (a -> [CyclicTree a b] -> CyclicTree a b) -> CyclicTree a b -> CyclicTree a b
replaceNode f tree = fold (idAlg { fNode = \a ts -> f a ts }) tree

-- | Rewrites every 'Label' bottom-up. The function receives the identifier
-- and the already-rewritten subtree; all other constructors are rebuilt
-- unchanged.
replaceLabel :: (Id -> CyclicTree a b -> CyclicTree a b) -> CyclicTree a b -> CyclicTree a b
replaceLabel f tree = fold (idAlg { fLabel = \l t -> f l t }) tree

-- | Substitutes a whole tree for every leaf, based on the leaf's value; all
-- other constructors are rebuilt unchanged. This is the bind of the 'Monad'
-- instance.
replaceLeaf :: (b -> CyclicTree a c) -> CyclicTree a b -> CyclicTree a c
replaceLeaf f tree = fold (idAlg { fLeaf = \b -> f b }) tree

--------------------------------------------------------------
-- Fold and algebra

-- | Bottom-up fold over the tree structure: each constructor is replaced by
-- the matching algebra field, applied to the folded subtrees. 'Rec' and
-- 'Var' are handed to the algebra as-is, so cycles are not unwound.
fold :: CyclicTreeAlg a b t -> CyclicTree a b -> t
fold alg = go
 where
   go tree =
      case tree of
         Node a children -> fNode alg a (map go children)
         Leaf b          -> fLeaf alg b
         Label l body    -> fLabel alg l (go body)
         Rec n body      -> fRec alg n (go body)
         Var n           -> fVar alg n

-- | Like 'fold', but cycles are unwound instead of exposed: the algebra's
-- 'fRec' and 'fVar' fields are never used. Every 'Var' stands for the result
-- computed for its nearest enclosing 'Rec' with the same number, so the
-- result behaves like a fold over the (possibly infinite) unfolded tree.
-- Whether this terminates depends on how lazily the algebra's 'fNode',
-- 'fLeaf' and 'fLabel' use their arguments. A 'Var' with no enclosing 'Rec'
-- raises the error @"foldUnwind: unbound var"@ when its value is demanded.
--
-- Implementation: the tree is folded into a function from an environment
-- (variable number to result) to a result. 'start' applies that function to
-- an environment in which every lookup fails with the error above.
foldUnwind :: CyclicTreeAlg a b t -> CyclicTree a b -> t
foldUnwind alg = start . fold Alg
   { fNode  = \a fs sub -> fNode alg a (map ($ sub) fs)
   , fLeaf  = \b _      -> fLeaf alg b
   , fLabel = \l f sub  -> fLabel alg l (f sub)
     -- knot-tying: the body is evaluated in an environment where n is bound
     -- to the very result being defined here
   , fRec   = \n f sub  -> let this = f (extend n this sub)
                           in this
   , fVar   = \n sub    -> sub n
   }
 where
   start f = f (error "foldUnwind: unbound var")
   -- bind variable n to a on top of environment sub; the newest binding
   -- shadows any outer binding with the same number
   extend n a sub i
      | i == n    = a
      | otherwise = sub i

-- | An algebra for 'fold': one function per constructor of 'CyclicTree'.
-- Each function receives the already-folded results (of type @t@) of the
-- corresponding subtrees.
data CyclicTreeAlg a b t = Alg
   { fNode  :: a -> [t] -> t   -- ^ node value and folded children
   , fLeaf  :: b -> t          -- ^ leaf value
   , fLabel :: Id -> t -> t    -- ^ identifier and folded subtree
   , fRec   :: Int -> t -> t   -- ^ binder number and folded body
   , fVar   :: Int -> t        -- ^ back-reference number
   }

-- | The identity algebra: rebuilds the tree with its own constructors, so
-- @fold idAlg@ returns an equal tree. Override individual fields to rewrite
-- only particular kinds of nodes.
idAlg :: CyclicTreeAlg a b (CyclicTree a b)
idAlg = Alg
   { fNode  = Node
   , fLeaf  = Leaf
   , fLabel = Label
   , fRec   = Rec
   , fVar   = Var
   }

-- | An algebra in which every field is @error "emptyAlg: uninitialized"@.
-- Intended as a base for record update in which every field the fold
-- will reach is overridden.
emptyAlg :: CyclicTreeAlg a b t
emptyAlg = Alg
   { fNode  = uninitialized
   , fLeaf  = uninitialized
   , fLabel = uninitialized
   , fRec   = uninitialized
   , fVar   = uninitialized
   }
 where
   uninitialized = error "emptyAlg: uninitialized"

-- | Monoidal algebra: children of a node are combined with 'mconcat', and
-- 'Label'/'Rec' pass their body's result through. Leaves and variables
-- yield 'mempty'. Override 'fLeaf' and/or 'fVar' to collect something.
monoidAlg :: Monoid m => CyclicTreeAlg a b m
monoidAlg = Alg
   { fNode  = \_ results -> mconcat results
   , fLeaf  = \_ -> mempty
   , fLabel = \_ result -> result
   , fRec   = \_ result -> result
   , fVar   = \_ -> mempty
   }