{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}

module Data.PolyTree where

import Control.Applicative ( Applicative(liftA2) )
import Control.Lens
    ( iso,
      _Left,
      _Right,
      Plated(..),
      Iso',
      Lens,
      Lens',
      Prism',
      Traversal,
      Traversal' )
import Data.Bifoldable ( Bifoldable(bifoldMap) )
import Data.Bifunctor ( Bifunctor(bimap) )
import Data.Bitraversable ( Bitraversable(..) )
import Data.Functor.Apply ( Apply(liftF2, (<.>)) )
import Data.Functor.Classes
    ( showsBinaryWith,
      Eq1(..),
      Eq2(..),
      Ord1(..),
      Ord2(..),
      Show1(liftShowsPrec),
      Show2(..) )
import Data.Functor.Identity ( Identity(..) )
import Data.List.NonEmpty ( NonEmpty(..), nonEmpty, toList )
import Data.Semigroup.Bifoldable ( Bifoldable1(bifoldMap1) )
import Data.Semigroup.Bitraversable ( Bitraversable1(bitraverse1) )
import Data.Semigroup.Foldable ( Foldable1(foldMap1) )
import Data.Semigroup.Traversable ( Traversable1(traverse1) )
import qualified Data.Tree as Tree

-- $setup
-- >>> import Control.Lens

-- | The forest hanging below a 'Tree' node: a container @f@ whose entries
-- are either leaves of type @b@ ('Left') or nested subtrees ('Right').
type TreeForest f a b = f (Either b (Tree f a b))

-- | A 'TreeForest' whose labels and leaves share a single type.
type TreeForest' f a = TreeForest f a a

-- | A tree in which every node carries a label of type @a@ and a forest
-- (stored in the container @f@) of leaves of type @b@ and nested subtrees.
data Tree f a b =
  Tree
    a
    -- ^ The node's label.
    (TreeForest f a b)
    -- ^ The node's forest: leaves ('Left') and subtrees ('Right').

-- | A 'Tree' whose labels and leaves share a single type.
type Tree' f a = Tree f a a

-- | A 'Tree' whose forests are lists.
type TreeList a b = Tree [] a b

-- | A list-forest 'Tree' whose labels and leaves share a single type.
type TreeList' a = TreeList a a

-- | A 'Tree' whose forests hold exactly one entry ('Identity').
type Tree1 a b = Tree Identity a b

-- | A 'Tree1' whose labels and leaves share a single type.
type Tree1' a = Tree1 a a

-- | Two trees are equal when their labels match (first comparison) and
-- their forests match entry by entry: leaves with the second comparison,
-- subtrees recursively.
instance Eq1 f => Eq2 (Tree f) where
  liftEq2 eqLabel eqLeaf (Tree x xs) (Tree y ys) =
    eqLabel x y && liftEq eqEntry xs ys
    where
      eqEntry = liftEq2 eqLeaf (liftEq2 eqLabel eqLeaf)

-- | Lexicographic ordering: labels are compared first (first comparison);
-- ties are broken by the forests, comparing leaves with the second
-- comparison and subtrees recursively.
instance Ord1 f => Ord2 (Tree f) where
  liftCompare2 cmpLabel cmpLeaf (Tree x xs) (Tree y ys) =
    cmpLabel x y <> liftCompare cmpEntry xs ys
    where
      cmpEntry = liftCompare2 cmpLeaf (liftCompare2 cmpLabel cmpLeaf)

-- | Renders a tree as the constructor application @Tree label forest@
-- (via 'showsBinaryWith'), showing forest entries as 'Left' leaves and
-- 'Right' subtrees.
instance Show1 f => Show2 (Tree f) where
  liftShowsPrec2 showLabel showLabels showLeaf showLeaves prec (Tree x forest) =
    showsBinaryWith showLabel showForest "Tree" prec x forest
    where
      -- Renderers for nested subtrees, defined recursively via this instance.
      showSub = liftShowsPrec2 showLabel showLabels showLeaf showLeaves
      showSubs = liftShowList2 showLabel showLabels showLeaf showLeaves
      -- Renderers for one forest entry (Either leaf subtree).
      showEntry = liftShowsPrec2 showLeaf showLeaves showSub showSubs
      showEntries = liftShowList2 showLeaf showLeaves showSub showSubs
      showForest = liftShowsPrec showEntry showEntries

-- | Lifted equality over leaves; labels use their own 'Eq' instance.
instance (Eq a, Eq1 f) => Eq1 (Tree f a) where
  liftEq eqLeaf =
    liftEq2 (==) eqLeaf

-- | Lifted ordering over leaves; labels use their own 'Ord' instance.
instance (Ord a, Ord1 f) => Ord1 (Tree f a) where
  liftCompare cmpLeaf =
    liftCompare2 compare cmpLeaf

-- | Lifted rendering over leaves; labels use their own 'Show' instance.
instance (Show a, Show1 f) => Show1 (Tree f a) where
  liftShowsPrec showLeaf showLeaves =
    liftShowsPrec2 showsPrec showList showLeaf showLeaves

-- | Structural equality, delegating to the lifted 'Eq1' instance.
instance (Eq a, Eq1 f, Eq b) => Eq (Tree f a b) where
  lhs == rhs =
    liftEq (==) lhs rhs

-- | Structural ordering, delegating to the lifted 'Ord1' instance.
instance (Ord a, Ord1 f, Ord b) => Ord (Tree f a b) where
  compare lhs rhs =
    liftCompare compare lhs rhs

-- | Renders a tree via the lifted 'Show1' instance, e.g. @Tree "ab" []@.
instance (Show a, Show1 f, Show b) => Show (Tree f a b) where
  showsPrec =
    -- The second argument of 'liftShowsPrec' renders a list of leaves, so
    -- 'showList' is the intended function (as in the 'Show1' instance).
    liftShowsPrec showsPrec showList

-- | Maps labels with the first function and leaves with the second,
-- recursing into every subtree of the forest.
instance Functor f => Bifunctor (Tree f) where
  bimap onLabel onLeaf (Tree x forest) =
    Tree (onLabel x) (mapEntry <$> forest)
    where
      mapEntry = either (Left . onLeaf) (Right . bimap onLabel onLeaf)

-- | Maps every leaf, at any depth, leaving labels unchanged.
instance Functor f => Functor (Tree f a) where
  fmap onLeaf =
    bimap id onLeaf

-- | Labels are combined with '<>'; root forests are combined with 'liftF2'
-- of the container @f@, pairing entries as follows: function leaf with
-- value leaf gives a leaf; a function leaf is mapped over a value subtree;
-- a function subtree is applied to a value leaf; two subtrees combine
-- recursively with '<.>'.
instance (Apply f, Semigroup a) => Apply (Tree f a) where
  Tree lf fs <.> Tree lx xs =
    Tree (lf <> lx) (liftF2 zipEntry fs xs)
    where
      zipEntry ef ex =
        case (ef, ex) of
          (Left g, Left v) -> Left (g v)
          (Left g, Right tv) -> Right (g <$> tv)
          (Right tg, Left v) -> Right (($ v) <$> tg)
          (Right tg, Right tv) -> Right (tg <.> tv)

-- | Labels are combined with '<>' ('pure' uses 'mempty'), and root forests
-- are combined with 'liftA2' of the container @f@, pairing entries as:
--
-- * function leaf and value leaf: apply the function, giving a leaf;
-- * function leaf and value subtree: map the function over the subtree;
-- * function subtree and value leaf: apply the subtree's functions to it;
-- * two subtrees: combine them recursively with '<*>'.
--
-- >>> Tree "a" [] <*> Tree "b" [] :: TreeList String String
-- Tree "ab" []
--
-- >>> Tree "a" [Left Prelude.reverse] <*> Tree "b" [Left "xyz"] :: TreeList String String
-- Tree "ab" [Left "zyx"]
--
-- >>> Tree "a" [Left Prelude.reverse] <*> Tree "b" [Left "xyz", makeChild "c" [Left "pqr"], makeChild "d" [Left "mno"]] :: TreeList String String
-- Tree "ab" [Left "zyx",Right (Tree "c" [Left "rqp"]),Right (Tree "d" [Left "onm"])]
instance (Applicative f, Monoid a) => Applicative (Tree f a) where
  pure =
    Tree mempty . pure . Left
  Tree lf fs <*> Tree lx xs =
    Tree (lf <> lx) (liftA2 zipEntry fs xs)
    where
      zipEntry ef ex =
        case (ef, ex) of
          (Left g, Left v) -> Left (g v)
          (Left g, Right tv) -> Right (g <$> tv)
          (Right tg, Left v) -> Right (($ v) <$> tg)
          (Right tg, Right tv) -> Right (tg <*> tv)

-- | Folds a node's label, then its forest in order: leaves with the second
-- function and subtrees recursively.
instance Foldable f => Bifoldable (Tree f) where
  bifoldMap onLabel onLeaf (Tree x forest) =
    onLabel x <> foldMap summarise forest
    where
      summarise (Left b) = onLeaf b
      summarise (Right sub) = bifoldMap onLabel onLeaf sub

-- | Semigroup fold of a node's label followed by its (non-empty) forest:
-- leaves with the second function, subtrees recursively.
instance Foldable1 f => Bifoldable1 (Tree f) where
  bifoldMap1 onLabel onLeaf (Tree x forest) =
    onLabel x <> foldMap1 summarise forest
    where
      summarise (Left b) = onLeaf b
      summarise (Right sub) = bifoldMap1 onLabel onLeaf sub

-- | Folds every leaf at any depth, in forest order; labels are skipped.
instance Foldable f => Foldable (Tree f a) where
  foldMap onLeaf (Tree _ forest) =
    foldMap summarise forest
    where
      summarise (Left b) = onLeaf b
      summarise (Right sub) = foldMap onLeaf sub

-- | Semigroup fold over every leaf at any depth; labels are skipped.
instance Foldable1 f => Foldable1 (Tree f a) where
  foldMap1 onLeaf (Tree _ forest) =
    foldMap1 summarise forest
    where
      summarise (Left b) = onLeaf b
      summarise (Right sub) = foldMap1 onLeaf sub

-- | Traverses a node's label, then its forest in order: leaves with the
-- second action, subtrees recursively.
instance Traversable f => Bitraversable (Tree f) where
  bitraverse onLabel onLeaf (Tree x forest) =
    Tree <$> onLabel x <*> traverse visit forest
    where
      visit (Left b) = Left <$> onLeaf b
      visit (Right sub) = Right <$> bitraverse onLabel onLeaf sub

-- | 'Apply'-based traversal of a node's label, then its (non-empty) forest:
-- leaves with the second action, subtrees recursively.
instance Traversable1 f => Bitraversable1 (Tree f) where
  bitraverse1 onLabel onLeaf (Tree x forest) =
    Tree <$> onLabel x <.> traverse1 visit forest
    where
      visit (Left b) = Left <$> onLeaf b
      visit (Right sub) = Right <$> bitraverse1 onLabel onLeaf sub

-- | Traverses every leaf at any depth, in forest order; labels are kept.
instance Traversable f => Traversable (Tree f a) where
  traverse onLeaf (Tree x forest) =
    Tree x <$> traverse visit forest
    where
      visit (Left b) = Left <$> onLeaf b
      visit (Right sub) = Right <$> traverse onLeaf sub

-- | 'Apply'-based traversal of every leaf at any depth; labels are kept.
instance Traversable1 f => Traversable1 (Tree f a) where
  traverse1 onLeaf (Tree x forest) =
    Tree x <$> traverse1 visit forest
    where
      visit (Left b) = Left <$> onLeaf b
      visit (Right sub) = Right <$> traverse1 onLeaf sub

-- | The plates of a node are the subtrees directly in its forest; leaves
-- are passed through untouched.
instance Traversable f => Plated (Tree f a b) where
  plate onChild (Tree x forest) =
    Tree x <$> traverse visit forest
    where
      visit (Left b) = pure (Left b)
      visit (Right sub) = Right <$> onChild sub

-- | Type-changing lens onto a node's forest. Unlike 'treeForest' it may
-- change the forest's container and leaf types; the label type is kept.
treeForest' ::
  Lens
    (Tree f a b)
    (Tree f' a b')
    (TreeForest f a b)
    (TreeForest f' a b')
treeForest' visit (Tree x forest) =
  Tree x <$> visit forest

-- | Traversal over the entries (leaves and subtrees alike) directly in a
-- node's forest, allowing the leaf type to change.
treeSubForest ::
  Traversable f =>
  Traversal
    (Tree f a b)
    (Tree f a b')
    (Either b (Tree f a b))
    (Either b' (Tree f a b'))
treeSubForest visit =
  treeForest' (traverse visit)

-- | Traversal over the leaves directly in a node's forest. Leaves inside
-- nested subtrees are not visited; the 'Traversable' instance reaches
-- leaves at every depth.
treeLeaves ::
  Traversable f =>
  Traversal'
    (Tree f a b)
    b
treeLeaves visit =
  treeSubForest (_Left visit)

-- | Traversal over the subtrees directly in a node's forest (not their
-- descendants).
treeForestChildren ::
  Traversable f =>
  Traversal'
    (Tree f a b)
    (Tree f a b)
treeForestChildren visit =
  treeSubForest (_Right visit)

-- | Types that contain (or are) a 'Tree', with lenses onto that tree, its
-- root label and its root forest.
class HasTree x f a b | x -> f a b where
  -- | Lens onto the contained tree.
  tree ::
    Lens' x (Tree f a b)
  {-# INLINE treeLabel #-}
  -- | Lens onto the root label; by default it focuses through 'tree'.
  treeLabel ::
    Lens' x a
  treeLabel visit =
    tree (treeLabel visit)
  {-# INLINE treeForest #-}
  -- | Lens onto the root forest; by default it focuses through 'tree'.
  treeForest ::
    Lens' x (TreeForest f a b)
  treeForest visit =
    tree (treeForest visit)

-- | A tree trivially has itself; the label and forest lenses focus on the
-- two constructor fields directly.
instance HasTree (Tree f a b) f a b where
  tree =
    id
  {-# INLINE treeLabel #-}
  treeLabel visit (Tree x forest) =
    (\x' -> Tree x' forest) <$> visit x
  {-# INLINE treeForest #-}
  treeForest visit (Tree x forest) =
    Tree x <$> visit forest

-- | Types that may be viewed as a 'Tree' through a prism.
class AsTree x f a b | x -> f a b where
  -- | Prism focusing on the 'Tree' case of @x@.
  _Tree :: Prism' x (Tree f a b)

-- | Every tree is trivially a tree: the prism is the identity.
instance AsTree (Tree f a b) f a b where
  _Tree focus =
    focus

-- | Depth-first (pre-order) listing of a tree: a node's label ('Left')
-- comes first, followed by its forest in order, where leaves appear as
-- 'Right' and subtrees are listed recursively in place.
--
-- >>> dfs (Tree 1 [])
-- Left 1 :| []
--
-- >>> dfs (Tree 1 [Left 2])
-- Left 1 :| [Right 2]
--
-- >>> dfs (Tree 1 [Left 2, makeChild 3 []])
-- Left 1 :| [Right 2,Left 3]
--
-- >>> dfs (Tree 1 [Left 2, makeChild 3 [], Left 4])
-- Left 1 :| [Right 2,Left 3,Right 4]
--
-- >>> dfs (Tree 1 [Left 2, makeChild 3 [Left 5], Left 4])
-- Left 1 :| [Right 2,Left 3,Right 5,Right 4]
--
-- >>> dfs (Tree 1 [Left 2, makeChild 3 [Left 5], Left 4, makeChild 6 []])
-- Left 1 :| [Right 2,Left 3,Right 5,Right 4,Left 6]
dfs ::
  Foldable f =>
  Tree f a b ->
  NonEmpty (Either a b)
dfs (Tree rootLabel rootForest) =
  let -- Prepend the pre-order listing of a forest onto 'rest'. Threading
      -- the tail through avoids re-copying a subtree's output once per
      -- ancestor (which made deep trees quadratic).
      walk forest rest =
        foldr visit rest forest
      visit (Left b) rest =
        Right b : rest
      visit (Right (Tree a t)) rest =
        Left a : walk t rest
  in  Left rootLabel :| walk rootForest []

-- | Breadth-first listing of a tree: nodes are visited level by level,
-- left to right, and each node's label ('Left') is immediately followed by
-- the leaves ('Right') directly in its own forest.
--
-- >>> bfs (Tree 1 [])
-- Left 1 :| []
--
-- >>> bfs (Tree 1 [Left 2])
-- Left 1 :| [Right 2]
--
-- >>> bfs (Tree 1 [Left 2, makeChild 3 []])
-- Left 1 :| [Right 2,Left 3]
--
-- >>> bfs (Tree 1 [Left 2, makeChild 3 [], Left 4])
-- Left 1 :| [Right 2,Right 4,Left 3]
--
-- >>> bfs (Tree 1 [Left 2, makeChild 3 [Left 5], Left 4])
-- Left 1 :| [Right 2,Right 4,Left 3,Right 5]
--
-- >>> bfs (Tree 1 [Left 2, makeChild 3 [Left 5], Left 4, makeChild 6 []])
-- Left 1 :| [Right 2,Right 4,Left 3,Right 5,Left 6]
bfs ::
  Foldable f =>
  Tree f a b
  -> NonEmpty (Either a b)
bfs (Tree rootLabel rootForest) =
  let -- Split one forest into its leaves (already wrapped for output) and
      -- its direct subtrees (this node's contribution to the next level).
      split forest =
        foldMap (either (\b -> ([Right b], [])) (\child -> ([], [child]))) forest
      -- Output for a single node: its label, then its own leaves.
      visit (Tree a t) =
        let (leaves, children) = split t
        in  (Left a : leaves, children)
      -- Handle a whole level at once instead of appending each node's
      -- children to the end of a list-based queue, which was quadratic
      -- for wide trees.
      levels ts =
        case nonEmpty ts of
          Nothing ->
            []
          Just level ->
            let (out, next) = foldMap visit level
            in  out <> levels next
      (rootLeaves, rootChildren) =
        split rootForest
  in  Left rootLabel :| (rootLeaves <> levels rootChildren)

-- | Build a forest entry holding a subtree with the given label and
-- forest. The leaf type @x@ of the entry is left free.
makeChild ::
  a
  -> TreeForest f a b
  -> Either x (Tree f a b)
makeChild label =
  Right . Tree label

-- | Build a node whose forest consists only of the given leaves.
makeLeaves ::
  Functor f =>
  a
  -> f b
  -> Tree f a b
makeLeaves label leaves =
  Tree label (fmap Left leaves)

-- | Build a node whose forest consists only of the given subtrees.
makeChildren ::
  Functor f =>
  a
  -> f (Tree f a b)
  -> Tree f a b
makeChildren label children =
  Tree label (fmap Right children)

-- | Conversion between a 'TreeList'' and a "Data.Tree" tree.
--
-- Forward, every leaf becomes a childless 'Tree.Node' and every subtree is
-- converted recursively. Backward, a childless 'Tree.Node' below the root
-- becomes a leaf and any other node becomes a subtree.
--
-- NOTE(review): this is not a lawful isomorphism. A subtree with an empty
-- forest (e.g. @Tree 1 [makeChild 2 []]@) converts to the same 'Tree.Tree'
-- as the leaf form @Tree 1 [Left 2]@, so a round trip yields the leaf form.
baseTree ::
  Iso' (TreeList' a) (Tree.Tree a)
baseTree =
  iso toRose fromRose
  where
    toRose (Tree label forest) =
      Tree.Node label (either (\leaf -> Tree.Node leaf []) toRose <$> forest)
    fromRose (Tree.Node label kids) =
      Tree label (entry <$> kids)
    entry (Tree.Node leaf []) =
      Left leaf
    entry node =
      Right (fromRose node)
