{-# LANGUAGE DefaultSignatures         #-}
{-# LANGUAGE DeriveFoldable            #-}
{-# LANGUAGE DeriveFunctor             #-}
{-# LANGUAGE DeriveTraversable         #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE MultiParamTypeClasses     #-}
{-# LANGUAGE RankNTypes                #-}
{-# LANGUAGE TypeFamilies              #-}
{-# LANGUAGE TypeOperators             #-}

module Control.Recursion
    ( -- * Typeclasses
      Base
    , Recursive (..)
    , Corecursive (..)
    -- * Types
    , Fix (..)
    , Mu (..)
    , Nu (..)
    , ListF (..)
    , NonEmptyF (..)
    -- * Recursion schemes
    , hylo
    , prepro
    , postpro
    , mutu
    , zygo
    , para
    , apo
    , hypo
    , elgot
    , coelgot
    , micro
    , meta
    , meta'
    , scolio
    , cata
    , ana
    -- * Mendler-style recursion schemes
    , mhisto
    , mcata
    -- * Monadic recursion schemes
    , cataM
    , anaM
    , hyloM
    , zygoM
    , zygoM'
    , scolioM
    , scolioM'
    , coelgotM
    , elgotM
    , paraM
    , mutuM
    , mutuM'
    , microM
    -- * Helper functions
    , lambek
    , colambek
    , refix
    ) where

import           Control.Arrow       ((&&&))
import           Control.Composition ((.*), (.**))
import           Control.Monad       ((<=<))
import           Data.Foldable       (toList)
import           Data.List.NonEmpty  (NonEmpty (..))
import qualified Data.List.NonEmpty  as NE
-- import           Data.Traversable    (Traversable (..))
import           GHC.Generics
import           Numeric.Natural     (Natural)

-- | Open type family mapping a recursive type @t@ to its base functor: the
-- shape of a single layer of @t@, with the recursive positions abstracted
-- into the functor's type parameter.
type family Base t :: * -> *

-- | Types that can be unrolled by one layer into their base functor.
class (Functor (Base t)) => Recursive t where

    -- | Expose the outermost layer of a value.
    project :: t -> Base t t

    -- The default implementation goes through "GHC.Generics": it converts the
    -- value to its generic representation, maps that onto the representation
    -- of @Base t t@ with 'hcoerce', and converts back.
    default project :: (Generic t, Generic (Base t t), HCoerce (Rep t) (Rep (Base t t))) => t -> Base t t
    project = to . hcoerce . from

-- | Types that can be built up by one layer from their base functor.
class (Functor (Base t)) => Corecursive t where

    -- | Wrap a single layer back into a value.
    embed :: Base t t -> t

    -- The default implementation goes through "GHC.Generics": it converts the
    -- layer to its generic representation, maps that onto the representation
    -- of @t@ with 'hcoerce', and converts back.
    default embed :: (Generic t, Generic (Base t t), HCoerce (Rep (Base t t)) (Rep t)) => Base t t -> t
    embed = to . hcoerce . from

-- | Base functor for a list of type @[a]@. The second type parameter stands
-- for the recursive position (the tail).
data ListF a b = Cons a b -- ^ An element together with the recursive position.
               | Nil -- ^ The end of the list.
               deriving (Functor, Foldable, Traversable)

-- | Base functor for 'NonEmpty': a head element together with an optional
-- recursive position for the rest of the list.
data NonEmptyF a b = NonEmptyF a (Maybe b)
    deriving (Functor, Foldable, Traversable)

-- | Fixed point of a functor: a value of @Fix f@ wraps one layer of @f@ whose
-- recursive positions are again @Fix f@.
newtype Fix f = Fix { unFix :: f (Fix f) }

-- | Nu (Ν): a coalgebra @a -> f a@ packaged together with a seed of the
-- existentially hidden type @a@.
data Nu f = forall a. Nu (a -> f a) a

-- | Mu (Μ): a value represented by the function that, given any algebra
-- @f a -> a@, produces the result of folding the value with that algebra.
newtype Mu f = Mu (forall a. (f a -> a) -> a)

-- Base functors of the types supported by this module.

-- The three fixed-point encodings all have the underlying functor as base.
type instance Base (Fix f) = f

type instance Base (Mu f) = f

type instance Base (Nu f) = f

-- A natural number unrolls to 'Nothing' (zero) or 'Just' its predecessor.
type instance Base Natural = Maybe

type instance Base [a] = ListF a

type instance Base (NonEmpty a) = NonEmptyF a

-- | Zero unrolls to 'Nothing'; every other natural unrolls to 'Just' its
-- predecessor.
instance Recursive Natural where
    project n
        | n == 0    = Nothing
        | otherwise = Just (n - 1)

-- | 'Nothing' becomes zero; 'Just' a natural becomes its successor.
instance Corecursive Natural where
    embed = maybe 0 (+ 1)

-- | Unroll one layer by running the stored coalgebra on the seed and
-- repackaging every resulting seed with the same coalgebra.
instance Functor f => Recursive (Nu f) where
    project (Nu step seed) = fmap (Nu step) (step seed)

-- | Wrap a layer of 'Nu' values into a single 'Nu'.
--
-- The new value uses @fmap project@ as its coalgebra and the given layer as
-- its seed. This is defined directly rather than via 'colambek': 'colambek'
-- is @ana (fmap project)@ and 'ana' itself calls 'embed', so going through it
-- only terminates when the @"ana/Nu"@ rewrite rule happens to fire.
instance Functor f => Corecursive (Nu f) where
    embed = Nu (fmap project)

-- | Unroll one layer of a 'Mu' value.
--
-- The stored fold is applied to the algebra @fmap embed@, which rebuilds the
-- sub-structures while leaving the outermost layer exposed. This is defined
-- directly rather than via 'lambek': 'lambek' is @cata (fmap embed)@ and
-- 'cata' itself calls 'project', so going through it only terminates when the
-- @"cata/Mu"@ rewrite rule happens to fire.
instance Functor f => Recursive (Mu f) where
    project (Mu fold) = fold (fmap embed)

-- | Wrap a layer of 'Mu' values: folding the result with an algebra folds
-- each child with that algebra first and then applies it to the layer.
instance Functor f => Corecursive (Mu f) where
    embed layer = Mu (\alg -> alg (cata alg <$> layer))

-- | A list unrolls to 'Nil' when empty and to 'Cons' of head and tail otherwise.
instance Recursive [a] where
    project list = case list of
        x : xs -> Cons x xs
        []     -> Nil

-- | 'Nil' becomes the empty list; 'Cons' prepends its element to the tail.
instance Corecursive [a] where
    embed layer = case layer of
        Cons x xs -> x : xs
        Nil       -> []

-- | Unroll a non-empty list into its head and, when the tail is non-empty,
-- the tail as another 'NonEmpty'.
instance Recursive (NonEmpty a) where
    -- 'NE.nonEmpty' is total: it gives 'Nothing' for an empty tail and 'Just'
    -- the tail otherwise, so the partial 'NE.fromList' is not needed.
    project (x :| xs) = NonEmptyF x (NE.nonEmpty xs)

-- | Rebuild a non-empty list from its head and an optional non-empty tail.
instance Corecursive (NonEmpty a) where
    embed (NonEmptyF x rest) = case rest of
        Just xs -> x :| toList xs
        Nothing -> x :| []

-- | Unrolling a 'Fix' just removes the newtype wrapper.
instance Functor f => Recursive (Fix f) where
    project (Fix layer) = layer

-- | Rolling up a 'Fix' just applies the newtype wrapper.
instance Functor f => Corecursive (Fix f) where
    embed layer = Fix layer

-- | Run a monadic action producing an 'Either' and continue with the handler
-- matching the constructor it returned.
eitherM :: Monad m => (a -> m c) -> (b -> m c) -> m (Either a b) -> m c
eitherM onLeft onRight action = action >>= either onLeft onRight

-- | Catamorphism: fold a structure bottom-up with an algebra. Each layer is
-- exposed with 'project', its children are folded recursively, and the
-- algebra is applied to the result.
-- (see [here](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.41.125&rep=rep1&type=pdf))
--
-- >>> :{
-- let {
--   sum' :: (Num a) => [a] -> a ;
--   sum' = cata a
--     where
--       a Nil         = 0
--       a (Cons x xs) = x + xs
-- }
-- :}
--
-- >>> sum' [1..100]
-- 5050
cata :: (Recursive t) => (Base t a -> a) -> t -> a
cata alg = go
    where go term = alg (fmap go (project term))
{-# NOINLINE [0] cata #-}

-- Rewrite rule: folding a 'Mu' value with an algebra is replaced by applying
-- the function stored inside the 'Mu' directly to that algebra.
{-# RULES
  "cata/Mu" forall f (g :: forall a. (f a -> a) -> a). cata f (Mu g) = g f;
     #-}

-- | Anamorphism: build a structure top-down from a seed. The coalgebra
-- produces one layer of seeds, each seed is unfolded recursively, and the
-- layer is wrapped with 'embed'.
ana :: (Corecursive t) => (a -> Base t a) -> a -> t
ana coalg = go
    where go seed = embed (fmap go (coalg seed))
{-# NOINLINE [0] ana #-}

-- Rewrite rule: unfolding into a 'Nu' with a coalgebra is replaced by storing
-- that coalgebra in the 'Nu' constructor.
{-# RULES
   "ana/Nu" forall (f :: a -> f a). ana f = Nu f;
      #-}

-- | Hylomorphism: fold a structure while building it up. The coalgebra
-- unfolds a seed into one layer, the children are processed recursively, and
-- the algebra collapses the layer, so no intermediate structure is named.
hylo :: Functor f => (f b -> b) -> (a -> f a) -> a -> b
hylo alg coalg = go
    where go seed = alg (fmap go (coalg seed))
{-# NOINLINE [0] hylo #-}

-- Rewrite rule: an unfold immediately consumed by a fold is replaced by a
-- single 'hylo' with the same algebra and coalgebra.
{-# RULES
  "ana/cata/hylo"  forall f g x. cata f (ana g x) = hylo f g x;
     #-}

-- | Pair the results of two applicative actions, combining their effects
-- with '<*>' (first argument first).
zipA :: (Applicative f) => f a -> f b -> f (a, b)
zipA left right = fmap (,) left <*> right

-- | Pair the results of two monadic actions.
--
-- The first action runs before the second, matching the argument order and
-- the effect order of 'zipA'; the result tuple keeps the same order.
zipM :: (Monad m) => m a -> m b -> m (a, b)
zipM x y = do
    a <- x
    b <- y
    pure (a, b)

-- | Monadic catamorphism: the children of each layer are folded with
-- 'traverse' before the monadic algebra is applied to the layer.
cataM :: (Recursive t, Traversable (Base t), Monad m) => (Base t a -> m a) -> t -> m a
cataM alg = go
    where go term = traverse go (project term) >>= alg

-- | Monadic paramorphism: the algebra sees, for every child, both the
-- rebuilt sub-structure and the result already computed for it.
paraM :: (Recursive t, Corecursive t, Traversable (Base t), Monad m) => (Base t (t, a) -> m a) -> t -> m a
paraM alg term = snd <$> cataM step term
    where
        -- Pair the rebuilt structure with the algebra's result for this layer.
        step layer = fmap (\result -> (embed (fmap fst layer), result)) (alg layer)

-- | Monadic zygomorphism: an auxiliary fold runs alongside the main one and
-- its results are available to the main algebra. The two results of each
-- layer are paired with 'zipA'.
zygoM :: (Recursive t, Traversable (Base t), Monad m) => (Base t b -> m b) -> (Base t (b, a) -> m a) -> t -> m a
zygoM aux alg term = snd <$> cataM both term
    where both layer = zipA (aux (fst <$> layer)) (alg layer)

-- | Variant of 'zygoM' that pairs the two results of each layer with 'zipM'
-- instead of 'zipA'.
zygoM' :: (Recursive t, Traversable (Base t), Monad m) => (Base t b -> m b) -> (Base t (b, a) -> m a) -> t -> m a
zygoM' aux alg term = snd <$> cataM both term
    where both layer = zipM (aux (fst <$> layer)) (alg layer)

-- | Monadic fold driven by two algebras over the same layer, one producing a
-- @t@ and one producing an @a@; the @a@ component is returned. The two
-- results of each layer are paired with 'zipA'.
scolioM :: (Recursive t, Traversable (Base t), Monad m) => (Base t (t, a) -> m t) -> (Base t (t, a) -> m a) -> t -> m a
scolioM f g term = snd <$> cataM both term
    where both layer = zipA (f layer) (g layer)

-- | Variant of 'scolioM' that pairs the two results of each layer with
-- 'zipM' instead of 'zipA'.
scolioM' :: (Recursive t, Traversable (Base t), Monad m) => (Base t (t, a) -> m t) -> (Base t (t, a) -> m a) -> t -> m a
scolioM' f g term = snd <$> cataM both term
    where both layer = zipM (f layer) (g layer)

-- | Monadic anamorphism: the coalgebra produces one layer of seeds inside the
-- monad, the seeds are unfolded with 'traverse', and the layer is wrapped
-- with 'embed'.
anaM :: (Corecursive t, Traversable (Base t), Monad m) => (a -> m (Base t a)) -> a -> m t
anaM coalg = go
    where go seed = coalg seed >>= \layer -> embed <$> traverse go layer

-- | Monadic hylomorphism: unfold a seed into a layer, process the children
-- recursively with 'traverse', then collapse the layer with the algebra.
hyloM :: (Traversable f, Monad m) => (f b -> m b) -> (a -> m (f a)) -> a -> m b
hyloM alg coalg = go
    where go seed = coalg seed >>= traverse go >>= alg

-- | Monadic Elgot algebra: the coalgebra either short-circuits with a final
-- result ('Left') or yields a layer of seeds ('Right') whose children are
-- processed recursively before the algebra collapses the layer.
elgotM :: (Traversable f, Monad m) => (f a -> m a) -> (b -> m (Either a (f b))) -> b -> m a
elgotM alg coalg = go
    where
        go seed = coalg seed >>= either pure descend
        descend layer = traverse go layer >>= alg

-- | Monadic anamorphism allowing shortcuts: 'elgotM' whose algebra simply
-- wraps each finished layer with 'embed'.
microM :: (Corecursive a, Traversable (Base a), Monad m) => (b -> m (Either a (Base a b))) -> b -> m a
microM coalg = elgotM (\layer -> pure (embed layer)) coalg

-- | Monadic co-(Elgot algebra): like 'hyloM', but the algebra also receives
-- the seed that the layer was unfolded from.
coelgotM :: (Traversable f, Monad m) => ((a, f b) -> m b) -> (a -> m (f a)) -> a -> m b
coelgotM alg coalg = go
    where
        go seed = fmap ((,) seed) (children seed) >>= alg
        -- Unfold one layer and process its children recursively.
        children seed = coalg seed >>= traverse go

-- | Unroll one layer of a structure, expressed as a fold whose algebra
-- re-embeds the children of each layer.
lambek :: (Recursive t, Corecursive t) => (t -> Base t t)
lambek = cata (\layer -> embed <$> layer)

-- | Wrap one layer into a structure, expressed as an unfold whose coalgebra
-- projects the children of each layer.
colambek :: (Recursive t, Corecursive t) => (Base t t -> t)
colambek = ana (\layer -> project <$> layer)

-- | Prepromorphism: a fold in which every child is first rewritten, layer by
-- layer, with the given transformation before being folded itself.
prepro :: (Recursive t, Corecursive t) => (Base t t -> Base t t) -> (Base t a -> a) -> t -> a
prepro nat alg = go
    where
        go term = alg (fmap (go . rewrite) (project term))
        -- Apply the transformation to every layer of a sub-structure.
        rewrite = cata (embed . nat)

-- | Postpromorphism: an unfold in which every freshly built child is
-- rewritten, layer by layer, with the given transformation.
postpro :: (Recursive t, Corecursive t) => (Base t t -> Base t t) -> (a -> Base t a) -> a -> t
postpro nat coalg = go
    where
        go seed = embed (fmap (rewrite . go) (coalg seed))
        -- Apply the transformation to every layer of a built sub-structure.
        rewrite = ana (nat . project)

-- | A mutumorphism: two algebras are folded together, each seeing the pair
-- of results computed for every child. The result of the second algebra is
-- returned.
--
-- >>> :{
-- let {
--   even' :: Natural -> Bool ;
--   even' = mutu o e
--     where
--       o :: Maybe (Bool, Bool) -> Bool
--       o Nothing = False
--       o (Just (_, b)) = b
--       e :: Maybe (Bool, Bool) -> Bool
--       e Nothing = True
--       e (Just (a, _)) = a
-- }
-- :}
--
-- >>> even' 10
-- True
mutu :: (Recursive t) => (Base t (a, a) -> a) -> (Base t (a, a) -> a) -> t -> a
mutu f g = snd . cata (\layer -> (f layer, g layer))

-- | Monadic mutumorphism: both algebras run on every layer and their results
-- are paired with 'zipA'; the second algebra's result is returned.
mutuM :: (Recursive t, Traversable (Base t), Monad m) => (Base t (a, a) -> m a) -> (Base t (a, a) -> m a) -> t -> m a
mutuM f g term = snd <$> cataM both term
    where both layer = zipA (f layer) (g layer)

-- | Variant of 'mutuM' that pairs the two results of each layer with 'zipM'
-- instead of 'zipA'.
mutuM' :: (Recursive t, Traversable (Base t), Monad m) => (Base t (a, a) -> m a) -> (Base t (a, a) -> m a) -> t -> m a
mutuM' f g term = snd <$> cataM both term
    where both layer = zipM (f layer) (g layer)

-- | Catamorphism collapsing along two data types simultaneously: both
-- algebras are applied to every layer and the first algebra's result is
-- returned.
scolio :: (Recursive t) => (Base t (a, t) -> a) -> (Base t (a, t) -> t) -> t -> a
scolio f g = fst . cata (f &&& g)

-- | Zygomorphism: a fold assisted by an auxiliary fold whose results are
-- visible to the main algebra
-- (see [here](http://www.iis.sinica.edu.tw/~scm/pub/mds.pdf) for a neat example)
--
-- >>> :set -XTypeFamilies
-- >>> import Data.Char (toUpper, toLower)
-- >>> :{
-- let {
--   spongebobZygo :: String -> String ;
--   spongebobZygo = zygo a pa
--     where
--       a :: ListF Char Bool -> Bool
--       a Nil          = False
--       a (Cons ' ' b) = b
--       a (Cons ',' b) = b
--       a (Cons _ b)   = not b
--       pa :: ListF Char (Bool, String) -> String
--       pa Nil                 = ""
--       pa (Cons c (True, s))  = toUpper c : s
--       pa (Cons c (False, s)) = toLower c : s
-- }
-- :}
--
-- >>> spongebobZygo "Hello, World"
-- "HeLlO, wOrLd"
zygo :: (Recursive t) => (Base t b -> b) -> (Base t (b, a) -> a) -> t -> a
zygo aux alg = snd . cata (aux . fmap fst &&& alg)

-- | Paramorphism: a fold whose algebra sees, for every child, both the
-- original sub-structure and the result computed for it.
--
-- >>> :{
-- let {
--   dedup :: Eq a => [a] -> [a] ;
--   dedup = para pa
--     where
--       pa :: Eq a => ListF a ([a], [a]) -> [a]
--       pa Nil = []
--       pa (Cons x (past, xs)) = if x `elem` past then xs else x:xs
-- }
-- :}
--
-- >>> dedup [1,1,2]
-- [1,2]
para :: (Recursive t, Corecursive t) => (Base t (t, a) -> a) -> t -> a
para alg = snd . cata (embed . fmap fst &&& alg)

-- | Gibbons' metamorphism: fold a structure down to a value, convert that
-- value, and unfold a new structure from it.
meta :: (Corecursive t', Recursive t) => (a -> Base t' a) -> (b -> a) -> (Base t b -> b) -> t -> t'
meta coalg convert alg term = ana coalg (convert (cata alg term))

-- | Erwig's metamorphism. Essentially a hylomorphism with a natural
-- transformation between unfolding and folding, which lets the coalgebra and
-- the algebra work over different functors.
meta' :: (Functor g) => (f a -> a) -> (forall c. g c -> f c) -> (b -> g b) -> b -> a
meta' alg nat coalg = go
    where go seed = alg (nat (fmap go (coalg seed)))

-- | Mendler's catamorphism: the algebra receives the recursive call as an
-- explicit argument alongside the unwrapped layer, so no 'Functor' instance
-- is required.
--
-- >>> import Data.Word (Word64)
-- >>> let asFix = cata Fix
-- >>> let base = (2 ^ (64 :: Int)) :: Integer
-- >>> :{
-- let {
--   wordListToInteger :: [Word64] -> Integer ;
--   wordListToInteger = mcata ma . asFix
--     where
--       ma f (Cons x xs) = fromIntegral x + base * f xs
--       ma _ Nil         = 0
-- }
-- :}
--
-- >>> wordListToInteger [1,0,1]
-- 340282366920938463463374607431768211457
mcata :: (forall y. ((y -> c) -> f y -> c)) -> Fix f -> c
mcata alg = go
    where go (Fix layer) = alg go layer

-- | Mendler's histomorphism: like 'mcata', but the algebra additionally
-- receives a function that unwraps a child into its own layer.
mhisto :: (forall y. ((y -> c) -> (y -> f y) -> f y -> c)) -> Fix f -> c
mhisto alg = go
    where go (Fix layer) = alg go unFix layer

-- | Elgot algebra (see [this paper](https://arxiv.org/abs/cs/0609040)): a
-- hylomorphism whose coalgebra may short-circuit with a final result
-- ('Left') instead of producing another layer ('Right').
--
-- >>> :{
-- let {
--   collatzLength :: Integer -> Integer ;
--   collatzLength = elgot a pc
--     where
--       pc :: Integer -> Either Integer (ListF Integer Integer)
--       pc 1 = Left 1
--       pc n
--         | n `mod` 2 == 0 = Right $ Cons n (div n 2)
--         | otherwise = Right $ Cons n (3 * n + 1)
--       a :: ListF Integer Integer -> Integer
--       a Nil        = 0
--       a (Cons _ x) = x + 1
-- }
-- :}
--
-- >>> collatzLength 2223
-- 183
elgot :: Functor f => (f a -> a) -> (b -> Either a (f b)) -> b -> a
elgot alg coalg = go
    where
        go seed = case coalg seed of
            Left done   -> done
            Right layer -> alg (fmap go layer)

-- | Anamorphism allowing shortcuts: the coalgebra may return a finished
-- structure ('Left') instead of another layer of seeds. Compare 'apo'.
micro :: (Corecursive a) => (b -> Either a (Base a b)) -> b -> a
micro coalg = elgot embed coalg

-- | Co-(Elgot algebra): a hylomorphism whose algebra also receives the seed
-- that the current layer was unfolded from.
--
-- >>> import Data.Word (Word64)
-- >>> let base = (2 ^ (64 :: Int)) :: Integer
-- >>> :{
-- let {
--   integerToWordList :: Integer -> [Word64] ;
--   integerToWordList = coelgot pa c
--     where
--       c i = Cons (fromIntegral (i `mod` (2 ^ (64 :: Int)))) (i `div` (2 ^ (64 :: Int)))
--       pa (i, ws) | i < 2 ^ (64 :: Int) = [fromIntegral i]
--                  | otherwise = embed ws
-- }
-- :}
--
-- >>> integerToWordList 340282366920938463463374607431768211457
-- [1,0,1]
coelgot :: Functor f => ((a, f b) -> b) -> (a -> f a) -> a -> b
coelgot alg coalg = go
    where go seed = alg (seed, fmap go (coalg seed))

-- | Apomorphism: an unfold in which each child position may hold either a
-- finished structure ('Left'), used as is, or a further seed ('Right'),
-- unfolded recursively. Compare 'micro'.
--
-- >>> :{
-- let {
--   isInteger :: (RealFrac a) => a -> Bool ;
--   isInteger = idem (realToFrac . floor)
--     where
--       idem f x = x == f x
-- }
-- :}
--
-- >>> :{
-- let {
--   continuedFraction :: (RealFrac a, Integral b) => a -> [b] ;
--   continuedFraction = apo pc
--     where
--       pc x
--         | isInteger x = go $ Left []
--         | otherwise   = go $ Right alpha
--           where
--             alpha = 1 / (x - realToFrac (floor x))
--             go    = Cons (floor x)
-- }
-- :}
--
-- >>> take 13 $ continuedFraction pi
-- [3,7,15,1,292,1,1,1,2,1,3,1,14]
--
-- >>> :{
-- let {
--   integerToWordList :: Integral a => a -> a -> [a] ;
--   integerToWordList base = apo pc
--     where
--       pc i | i < base  = Cons (fromIntegral i) (Left [])
--            | otherwise = Cons (fromIntegral (i `mod` base)) (Right (i `div` base))
-- }
-- :}
--
-- >>> integerToWordList 2 5
-- [1,0,1]
apo :: (Corecursive t) => (a -> Base t (Either t a)) -> a -> t
apo coalg = go
    where go seed = embed (fmap (either id go) (coalg seed))

-- | Hypomorphism: build a structure with 'apo' and consume it with 'para'.
--
-- @since 2.2.3.0
hypo :: (Recursive t, Corecursive t) => (a -> Base t (Either t a)) -> (Base t (t, b) -> b) -> a -> b
hypo coalg alg seed = para alg (apo coalg seed)

-- | Convert between two types sharing the same base functor by folding the
-- source with the target's 'embed'.
refix :: (Recursive s, Corecursive t, Base s ~ Base t) => s -> t
refix source = cata embed source

-- taken from http://hackage.haskell.org/package/recursion-schemes/docs/src/Data.Functor.Foldable.html#gcoerce
--
-- Conversion between two generic representations @f@ and @g@. The instances
-- map each representation constructor onto the same constructor on the other
-- side; it backs the default 'project' and 'embed' implementations.
class HCoerce f g where
    -- Convert a value of one representation into the other.
    hcoerce :: f a -> g a

-- Metadata wrappers: convert the payload; the metadata on the two sides may differ.
instance HCoerce f g => HCoerce (M1 i c f) (M1 i c' g) where
    hcoerce = M1 . hcoerce . unM1

-- Fields: the stored value is carried over unchanged.
instance HCoerce (K1 i c) (K1 j c) where
    hcoerce (K1 x) = K1 x

-- Constructors without fields.
instance HCoerce U1 U1 where
    hcoerce unit = unit

-- Types without constructors.
instance HCoerce V1 V1 where
    hcoerce void = void

-- Products: convert both components.
instance (HCoerce f g, HCoerce f' g') => HCoerce (f :*: f') (g :*: g') where
    hcoerce (l :*: r) = hcoerce l :*: hcoerce r

-- Sums: convert whichever alternative is present, keeping its side.
instance (HCoerce f g, HCoerce f' g') => HCoerce (f :+: f') (g :+: g') where
    hcoerce alt = case alt of
        L1 l -> L1 (hcoerce l)
        R1 r -> R1 (hcoerce r)