{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Data.PolyTree where

import Control.Applicative ( Applicative(liftA2) )
import Control.Lens
    ( view,
      iso,
      _Left,
      _Right,
      _Wrapped,
      Plated(..),
      Iso',
      Lens,
      Lens',
      Prism',
      Traversal,
      Traversal',
      Rewrapped,
      Wrapped(..) )
import Data.Bifoldable ( Bifoldable(bifoldMap) )
import Data.Bifunctor ( Bifunctor(bimap) )
import Data.Bitraversable ( Bitraversable(..) )
import Data.Functor.Apply ( Apply(liftF2, (<.>)) )
import Data.Functor.Classes
    ( showsBinaryWith,
      showsUnaryWith,
      Eq1(..),
      Eq2(..),
      Ord1(..),
      Ord2(..),
      Show1(liftShowsPrec),
      Show2(..) )
import Data.Functor.Identity ( Identity(..) )
import Data.List.NonEmpty ( NonEmpty(..), nonEmpty, toList )
import Data.Semigroup.Bifoldable ( Bifoldable1(bifoldMap1) )
import Data.Semigroup.Bitraversable ( Bitraversable1(bitraverse1) )
import Data.Semigroup.Foldable ( Foldable1(foldMap1) )
import Data.Semigroup.Traversable ( Traversable1(traverse1) )
import qualified Data.Tree as Tree

-- $setup
-- >>> import Control.Lens

-- | The forest hanging below a 'Tree' node: a container @f@ whose entries
-- are either leaves of type @b@ ('Left') or child subtrees ('Right').
newtype TreeForest f a b
  = TreeForest (f (Either b (Tree f a b)))

-- | A 'TreeForest' whose node labels and leaves share one type.
type TreeForest' f a = TreeForest f a a

-- | Permits type-changing use of '_Wrapped' on 'TreeForest': the target may
-- be a 'TreeForest' with different type arguments.
instance (t ~ TreeForest f' a' b') =>
  Rewrapped (TreeForest f a b) t

-- | Exposes the entry container underneath a 'TreeForest'.
instance Wrapped (TreeForest f a b) where
  type Unwrapped (TreeForest f a b) = f (Either b (Tree f a b))
  _Wrapped' = iso unwrap TreeForest
    where
      unwrap (TreeForest entries) = entries

-- | Entry-wise equality: leaves are compared with the @b@ relation and
-- subtrees recursively via 'Eq2' for 'Tree'.
instance Eq1 f => Eq2 (TreeForest f) where
  liftEq2 eqA eqB (TreeForest lhs) (TreeForest rhs) =
    let eqEntry = liftEq2 eqB (liftEq2 eqA eqB)
    in  liftEq eqEntry lhs rhs

-- | Compares the entry containers with 'liftCompare', using the @b@
-- comparison for leaves and 'Ord2' for 'Tree' on subtrees.
instance Ord1 f => Ord2 (TreeForest f) where
  liftCompare2 cmpA cmpB (TreeForest lhs) (TreeForest rhs) =
    let cmpEntry = liftCompare2 cmpB (liftCompare2 cmpA cmpB)
    in  liftCompare cmpEntry lhs rhs

-- | Renders as @TreeForest <container>@, printing leaves with the @b@
-- printers and nested subtrees with 'Show2' for 'Tree'.
instance Show1 f => Show2 (TreeForest f) where
  liftShowsPrec2 spA slA spB slB prec (TreeForest entries) =
    showsUnaryWith showsEntries "TreeForest" prec entries
    where
      showsTree = liftShowsPrec2 spA slA spB slB
      showListTree = liftShowList2 spA slA spB slB
      showsEntries =
        liftShowsPrec
          (liftShowsPrec2 spB slB showsTree showListTree)
          (liftShowList2 spB slB showsTree showListTree)

-- | Equality with the label type fixed to its own '=='.
instance (Eq a, Eq1 f) => Eq1 (TreeForest f a) where
  liftEq eqB = liftEq2 (==) eqB

-- | Ordering with the label type fixed to its own 'compare'.
instance (Ord a, Ord1 f) => Ord1 (TreeForest f a) where
  liftCompare cmpB = liftCompare2 compare cmpB

-- | Printing with the label type fixed to its own 'Show' instance.
instance (Show a, Show1 f) => Show1 (TreeForest f a) where
  liftShowsPrec spB slB = liftShowsPrec2 showsPrec showList spB slB

-- | Structural equality, delegating to 'Eq1'.
instance (Eq a, Eq1 f, Eq b) => Eq (TreeForest f a b) where
  lhs == rhs = liftEq (==) lhs rhs

-- | Structural ordering, delegating to 'Ord1'.
instance (Ord a, Ord1 f, Ord b) => Ord (TreeForest f a b) where
  compare lhs rhs = liftCompare compare lhs rhs

-- | Shows via 'Show1'. The second argument of 'liftShowsPrec' is the
-- list printer for @b@, so 'showList' (not 'shows') is the correct
-- method here, matching how the 'Show1' instance calls 'liftShowsPrec2'.
instance (Show a, Show1 f, Show b) => Show (TreeForest f a b) where
  showsPrec =
    liftShowsPrec showsPrec showList

-- | Maps labels with the first function and leaves with the second,
-- throughout every nested subtree.
instance Functor f => Bifunctor (TreeForest f) where
  bimap onA onB (TreeForest entries) =
    let mapEntry = bimap onB (bimap onA onB)
    in  TreeForest (mapEntry <$> entries)

-- | Maps the leaves only; labels are left unchanged.
instance Functor f => Functor (TreeForest f a) where
  fmap onB = bimap id onB

-- | Pairs entries with the container's 'liftF2'. Two leaves apply directly.
-- A leaf paired with a subtree is mapped over that subtree. Two subtrees
-- combine with '<.>' for 'Tree'.
instance (Apply f, Semigroup a) => Apply (TreeForest f a) where
  TreeForest fs <.> TreeForest xs =
    TreeForest (liftF2 apEntry fs xs)
    where
      apEntry ef ex =
        case (ef, ex) of
          (Left g, Left y) -> Left (g y)
          (Left g, Right ty) -> Right (g <$> ty)
          (Right tg, Left y) -> Right (($ y) <$> tg)
          (Right tg, Right ty) -> Right (tg <.> ty)

-- | 'pure' yields a forest holding a single leaf. '<*>' pairs entries with
-- the container's 'liftA2', descending into subtrees on either side.
instance (Applicative f, Monoid a) => Applicative (TreeForest f a) where
  pure leaf =
    TreeForest (pure (Left leaf))
  TreeForest fs <*> TreeForest xs =
    TreeForest (liftA2 apEntry fs xs)
    where
      apEntry (Left g) (Left y) = Left (g y)
      apEntry (Left g) (Right ty) = Right (fmap g ty)
      apEntry (Right tg) (Left y) = Right (fmap ($ y) tg)
      apEntry (Right tg) (Right ty) = Right (tg <*> ty)

-- | Folds entries in container order: leaves go through the second
-- function, and subtrees are folded recursively with both.
instance Foldable f => Bifoldable (TreeForest f) where
  bifoldMap onA onB (TreeForest entries) =
    foldMap visit entries
    where
      visit (Left leaf) = onB leaf
      visit (Right child) = bifoldMap onA onB child

-- | Semigroup fold over a non-empty container of entries. Leaves use the
-- second function, and subtrees are folded recursively with both.
instance Foldable1 f => Bifoldable1 (TreeForest f) where
  bifoldMap1 onA onB (TreeForest entries) =
    foldMap1 visit entries
    where
      visit (Left leaf) = onB leaf
      visit (Right child) = bifoldMap1 onA onB child

-- | Folds leaves in container order, descending into subtrees in place.
instance Foldable f => Foldable (TreeForest f a) where
  foldMap onB (TreeForest entries) =
    foldMap visit entries
    where
      visit (Left leaf) = onB leaf
      visit (Right child) = foldMap onB child

-- | Semigroup fold over the leaves, descending into subtrees in place.
instance Foldable1 f => Foldable1 (TreeForest f a) where
  foldMap1 onB (TreeForest entries) =
    foldMap1 visit entries
    where
      visit (Left leaf) = onB leaf
      visit (Right child) = foldMap1 onB child

-- | Traverses entries in container order: leaves with the second action,
-- and subtrees recursively with both.
instance Traversable f => Bitraversable (TreeForest f) where
  bitraverse onA onB (TreeForest entries) =
    fmap TreeForest (traverse visit entries)
    where
      visit (Left leaf) = fmap Left (onB leaf)
      visit (Right child) = fmap Right (bitraverse onA onB child)

-- | 'Apply'-based traversal of a non-empty container of entries. Leaves use
-- the second action, and subtrees are traversed recursively with both.
instance Traversable1 f => Bitraversable1 (TreeForest f) where
  bitraverse1 onA onB (TreeForest entries) =
    fmap TreeForest (traverse1 visit entries)
    where
      visit (Left leaf) = fmap Left (onB leaf)
      visit (Right child) = fmap Right (bitraverse1 onA onB child)

-- | Traverses leaves in container order, descending into subtrees in place.
instance Traversable f => Traversable (TreeForest f a) where
  traverse onB (TreeForest entries) =
    fmap TreeForest (traverse visit entries)
    where
      visit (Left leaf) = fmap Left (onB leaf)
      visit (Right child) = fmap Right (traverse onB child)

-- | 'Apply'-based traversal of the leaves, descending into subtrees in place.
instance Traversable1 f => Traversable1 (TreeForest f a) where
  traverse1 onB (TreeForest entries) =
    fmap TreeForest (traverse1 visit entries)
    where
      visit (Left leaf) = fmap Left (onB leaf)
      visit (Right child) = fmap Right (traverse1 onB child)

-- | Types from which a 'TreeForest' can be focused with a lens.
class HasTreeForest x f a b | x -> f a b where
  -- | Focus the 'TreeForest' inside @x@.
  treeForest :: Lens' x (TreeForest f a b)

-- | A 'TreeForest' trivially contains itself.
instance HasTreeForest (TreeForest f a b) f a b where
  treeForest focus = focus

-- | Types that may contain a 'TreeForest', focused with a prism.
class AsTreeForest x f a b | x -> f a b where
  -- | Match (or build) the 'TreeForest' case of @x@.
  _TreeForest :: Prism' x (TreeForest f a b)

-- | A 'TreeForest' always matches itself.
instance AsTreeForest (TreeForest f a b) f a b where
  _TreeForest focus = focus

-- | A node carrying a label of type @a@ and a 'TreeForest' whose entries
-- are either leaves of type @b@ or further subtrees.
data Tree f a b
  = Tree a (TreeForest f a b)

-- | A 'Tree' whose labels and leaves share one type.
type Tree' f a = Tree f a a

-- | A 'Tree' whose forests are lists.
type TreeList a b = Tree [] a b

-- | A 'TreeList' whose labels and leaves share one type.
type TreeList' a = TreeList a a

-- | A 'Tree' whose forests are 'Identity'. Every node therefore has exactly
-- one entry, which is either a single leaf or a single child subtree.
type Tree1 a b = Tree Identity a b

-- | A 'Tree1' whose labels and leaves share one type.
type Tree1' a = Tree1 a a

-- | Equal when the labels match and the forests match. The forests are only
-- compared once the labels are known to be equal.
instance Eq1 f => Eq2 (Tree f) where
  liftEq2 eqA eqB (Tree labelL forestL) (Tree labelR forestR) =
    eqA labelL labelR && liftEq2 eqA eqB forestL forestR

-- | Compares labels first, then compares forests only if the labels are
-- equal.
instance Ord1 f => Ord2 (Tree f) where
  liftCompare2 cmpA cmpB (Tree labelL forestL) (Tree labelR forestR) =
    case cmpA labelL labelR of
      EQ -> liftCompare2 cmpA cmpB forestL forestR
      unequal -> unequal

-- | Renders as @Tree <label> <forest>@.
instance Show1 f => Show2 (Tree f) where
  liftShowsPrec2 spA slA spB slB prec (Tree label forest) =
    let showsForest = liftShowsPrec2 spA slA spB slB
    in  showsBinaryWith spA showsForest "Tree" prec label forest

-- | Equality with the label type fixed to its own '=='.
instance (Eq a, Eq1 f) => Eq1 (Tree f a) where
  liftEq eqB = liftEq2 (==) eqB

-- | Ordering with the label type fixed to its own 'compare'.
instance (Ord a, Ord1 f) => Ord1 (Tree f a) where
  liftCompare cmpB = liftCompare2 compare cmpB

-- | Printing with the label type fixed to its own 'Show' instance.
instance (Show a, Show1 f) => Show1 (Tree f a) where
  liftShowsPrec spB slB = liftShowsPrec2 showsPrec showList spB slB

-- | Structural equality, delegating to 'Eq1'.
instance (Eq a, Eq1 f, Eq b) => Eq (Tree f a b) where
  lhs == rhs = liftEq (==) lhs rhs

-- | Structural ordering, delegating to 'Ord1'.
instance (Ord a, Ord1 f, Ord b) => Ord (Tree f a b) where
  compare lhs rhs = liftCompare compare lhs rhs

-- | Shows via 'Show1'. The second argument of 'liftShowsPrec' is the
-- list printer for @b@, so 'showList' (not 'shows') is the correct
-- method here, matching how the 'Show1' instance calls 'liftShowsPrec2'.
instance (Show a, Show1 f, Show b) => Show (Tree f a b) where
  showsPrec =
    liftShowsPrec showsPrec showList

-- | Maps this node's label with the first function, and the forest (labels
-- and leaves) with both.
instance Functor f => Bifunctor (Tree f) where
  bimap onA onB (Tree label forest) =
    let forest' = bimap onA onB forest
    in  Tree (onA label) forest'

-- | Maps the leaves only; every label is left unchanged.
instance Functor f => Functor (Tree f a) where
  fmap onB = bimap id onB

-- | Labels combine with '<>'. Forests combine entry-wise with '<.>' for
-- 'TreeForest'.
instance (Apply f, Semigroup a) => Apply (Tree f a) where
  Tree labelF forestF <.> Tree labelX forestX =
    Tree (labelF <> labelX) (forestF <.> forestX)

-- | Labels combine with '<>', and 'pure' uses 'mempty' for the label.
-- Forests combine with the 'TreeForest' 'Applicative' instance.
--
-- >>> Tree "a" (TreeForest []) <*> Tree "b" (TreeForest []) :: TreeList String String
-- Tree "ab" (TreeForest [])
--
-- >>> Tree "a" (TreeForest [Left Prelude.reverse]) <*> Tree "b" (TreeForest [Left "xyz"]) :: TreeList String String
-- Tree "ab" (TreeForest [Left "zyx"])
--
-- >>> Tree "a" (TreeForest [Left Prelude.reverse]) <*> Tree "b" (TreeForest [Left "xyz", makeChild "c" [Left "pqr"], makeChild "d" [Left "mno"]]) :: TreeList String String
-- Tree "ab" (TreeForest [Left "zyx",Right (Tree "c" (TreeForest [Left "rqp"])),Right (Tree "d" (TreeForest [Left "onm"]))])
instance (Applicative f, Monoid a) => Applicative (Tree f a) where
  pure leaf =
    Tree mempty (pure leaf)
  Tree labelF forestF <*> Tree labelX forestX =
    Tree (labelF <> labelX) (forestF <*> forestX)

-- | Folds this node's label first, then its forest.
instance Foldable f => Bifoldable (Tree f) where
  bifoldMap onA onB (Tree label forest) =
    onA label <> bifoldMap onA onB forest

-- | Semigroup fold of this node's label followed by its forest.
instance Foldable1 f => Bifoldable1 (Tree f) where
  bifoldMap1 onA onB (Tree label forest) =
    onA label <> bifoldMap1 onA onB forest

-- | Folds only the leaves; labels are skipped.
instance Foldable f => Foldable (Tree f a) where
  foldMap onB node =
    case node of
      Tree _ forest -> foldMap onB forest

-- | Semigroup fold over the leaves; labels are skipped.
instance Foldable1 f => Foldable1 (Tree f a) where
  foldMap1 onB node =
    case node of
      Tree _ forest -> foldMap1 onB forest

-- | The label's action is sequenced before the forest's.
instance Traversable f => Bitraversable (Tree f) where
  bitraverse onA onB (Tree label forest) =
    fmap Tree (onA label) <*> bitraverse onA onB forest

-- | 'Apply'-based traversal; the label's action is sequenced before the
-- forest's.
instance Traversable1 f => Bitraversable1 (Tree f) where
  bitraverse1 onA onB (Tree label forest) =
    fmap Tree (onA label) <.> bitraverse1 onA onB forest

-- | Traverses the leaves and keeps this node's label as is.
instance Traversable f => Traversable (Tree f a) where
  traverse onB (Tree label forest) =
    fmap (Tree label) (traverse onB forest)

-- | 'Apply'-based traversal of the leaves; this node's label is kept as is.
instance Traversable1 f => Traversable1 (Tree f a) where
  traverse1 onB (Tree label forest) =
    fmap (Tree label) (traverse1 onB forest)

-- | A node's immediate children are the subtrees in its forest. Leaves are
-- not visited and are kept in place.
instance Traversable f => Plated (Tree f a b) where
  plate =
    treeForestChildren

-- | Type-changing lens onto a node's forest. The label is kept, while the
-- container and leaf types may change.
treeForest' ::
  Lens
    (Tree f a b)
    (Tree f' a b')
    (TreeForest f a b)
    (TreeForest f' a b')
treeForest' onForest (Tree label forest) =
  Tree label <$> onForest forest

-- | Traverses the entries directly under a node (leaves and child subtrees,
-- in container order), allowing the leaf type to change.
treeSubForest ::
  Traversable f =>
  Traversal
    (Tree f a b)
    (Tree f a b')
    (Either b (Tree f a b))
    (Either b' (Tree f a b'))
treeSubForest onEntry =
  treeForest' (_Wrapped (traverse onEntry))

-- | Traverses the leaves directly under a node (not those inside child
-- subtrees).
treeLeaves ::
  Traversable f =>
  Traversal'
    (Tree f a b)
    b
treeLeaves onLeaf =
  treeSubForest (_Left onLeaf)

-- | Traverses the child subtrees directly under a node. Leaves are kept in
-- place.
treeForestChildren ::
  Traversable f =>
  Traversal'
    (Tree f a b)
    (Tree f a b)
treeForestChildren onChild =
  treeSubForest (_Right onChild)

-- | Types that contain a 'Tree' reachable through a lens. Only 'tree' must
-- be defined; by default 'treeLabel' focuses the label of that tree.
class HasTree x f a b | x -> f a b where
  -- | Focus the contained 'Tree'.
  tree :: Lens' x (Tree f a b)

  -- | Focus the label of the contained 'Tree'.
  treeLabel :: Lens' x a
  treeLabel = tree . treeLabel
  {-# INLINE treeLabel #-}

-- | A 'Tree' contains itself, and 'treeLabel' focuses the node's own label.
--
-- >>> view treeLabel (Tree "a" (TreeForest []))
-- "a"
--
-- >>> view treeForest (Tree "a" (TreeForest []))
-- TreeForest []
--
-- >>> view treeForest (Tree "b" (TreeForest [Left "xyz", makeChild "c" [Left "pqr"], makeChild "d" [Left "mno"]]))
-- TreeForest [Left "xyz",Right (Tree "c" (TreeForest [Left "pqr"])),Right (Tree "d" (TreeForest [Left "mno"]))]
instance HasTree (Tree f a b) f a b where
  tree focus = focus
  {-# INLINE treeLabel #-}
  treeLabel onLabel (Tree label forest) =
    (\label' -> Tree label' forest) <$> onLabel label

-- | A 'Tree' exposes its forest through the type-preserving use of
-- 'treeForest''.
instance HasTreeForest (Tree f a b) f a b where
  treeForest onForest = treeForest' onForest

-- | Types that may contain a 'Tree', focused with a prism.
class AsTree x f a b | x -> f a b where
  -- | Match (or build) the 'Tree' case of @x@.
  _Tree :: Prism' x (Tree f a b)

-- | A 'Tree' always matches itself.
instance AsTree (Tree f a b) f a b where
  _Tree focus = focus

-- | Depth-first, pre-order flattening. A node's label ('Left') comes first,
-- followed by its forest entries in order: leaves become 'Right', and child
-- subtrees are flattened recursively in place.
--
-- >>> dfs (Tree 1 (TreeForest []))
-- Left 1 :| []
--
-- >>> dfs (Tree 1 (TreeForest [Left 2]))
-- Left 1 :| [Right 2]
--
-- >>> dfs (Tree 1 (TreeForest [Left 2, makeChild 3 []]))
-- Left 1 :| [Right 2,Left 3]
--
-- >>> dfs (Tree 1 (TreeForest [Left 2, makeChild 3 [], Left 4]))
-- Left 1 :| [Right 2,Left 3,Right 4]
--
-- >>> dfs (Tree 1 (TreeForest [Left 2, makeChild 3 [Left 5], Left 4]))
-- Left 1 :| [Right 2,Left 3,Right 5,Right 4]
--
-- >>> dfs (Tree 1 (TreeForest [Left 2, makeChild 3 [Left 5], Left 4, makeChild 6 []]))
-- Left 1 :| [Right 2,Left 3,Right 5,Right 4,Left 6]
dfs ::
  Foldable f =>
  Tree f a b ->
  NonEmpty (Either a b)
dfs (Tree label forest) =
  Left label :| concatMap visit (view _Wrapped forest)
  where
    visit (Left leaf) = [Right leaf]
    visit (Right child) = toList (dfs child)

-- | Breadth-first flattening over nodes. Nodes are visited level by level,
-- and each node's label ('Left') is immediately followed by that node's own
-- leaves ('Right'). As a result, a node's leaves are emitted together with
-- the node itself and may precede labels of shallower nodes later in the
-- same level (see the last example).
--
-- >>> bfs (Tree 1 (TreeForest []))
-- Left 1 :| []
--
-- >>> bfs (Tree 1 (TreeForest [Left 2]))
-- Left 1 :| [Right 2]
--
-- >>> bfs (Tree 1 (TreeForest [Left 2, makeChild 3 []]))
-- Left 1 :| [Right 2,Left 3]
--
-- >>> bfs (Tree 1 (TreeForest [Left 2, makeChild 3 [], Left 4]))
-- Left 1 :| [Right 2,Right 4,Left 3]
--
-- >>> bfs (Tree 1 (TreeForest [Left 2, makeChild 3 [Left 5], Left 4]))
-- Left 1 :| [Right 2,Right 4,Left 3,Right 5]
--
-- >>> bfs (Tree 1 (TreeForest [Left 2, makeChild 3 [Left 5], Left 4, makeChild 6 []]))
-- Left 1 :| [Right 2,Right 4,Left 3,Right 5,Left 6]
bfs ::
  Foldable f =>
  Tree f a b
  -> NonEmpty (Either a b)
bfs (Tree topLabel topForest) =
  let (topLeaves, topChildren) = splitEntries (view _Wrapped topForest)
  in  Left topLabel :| (topLeaves <> levels topChildren)
  where
    -- Separate a forest into its leaf outputs and its child subtrees,
    -- preserving container order within each.
    splitEntries entries =
      foldMap (either (\leaf -> ([Right leaf], [])) (\child -> ([], [child]))) entries
    -- A node contributes its label followed by its own leaves. Its child
    -- subtrees belong to the next level.
    visit (Tree label forest) =
      let (leaves, children) = splitEntries (view _Wrapped forest)
      in  (Left label : leaves, children)
    -- Emit a whole level, then recurse on the concatenated children of that
    -- level. Building each level with 'concat' keeps the traversal linear.
    -- By contrast, appending every node's children to the tail of a queue
    -- with '<>' re-walks the queue and is quadratic in the level width.
    levels nodes =
      case nonEmpty nodes of
        Nothing -> []
        Just level ->
          let (outputs, nextLevel) = unzip (fmap visit (toList level))
          in  concat outputs <> levels (concat nextLevel)

-- | Build a child-subtree forest entry ('Right') from a label and the
-- child's own forest container.
--
-- >>> makeChild 1 []
-- Right (Tree 1 (TreeForest []))
--
-- >>> makeChild 1 [Left "a"]
-- Right (Tree 1 (TreeForest [Left "a"]))
--
-- >>> makeChild 1 [Left "a", makeChild 2 []]
-- Right (Tree 1 (TreeForest [Left "a",Right (Tree 2 (TreeForest []))]))
makeChild ::
  a
  -> f (Either b (Tree f a b))
  -> Either x (Tree f a b)
makeChild label =
  Right . Tree label . TreeForest

-- | Build a node whose forest consists only of the given leaves, in
-- container order.
--
-- >>> makeLeaves 1 []
-- Tree 1 (TreeForest [])
--
-- >>> makeLeaves 1 [makeLeaves 2 []]
-- Tree 1 (TreeForest [Left (Tree 2 (TreeForest []))])
--
-- >>> makeLeaves 1 [makeLeaves 2 [makeChild 3 []]]
-- Tree 1 (TreeForest [Left (Tree 2 (TreeForest [Left (Right (Tree 3 (TreeForest [])))]))])
makeLeaves ::
  Functor f =>
  a
  -> f b
  -> Tree f a b
makeLeaves label =
  Tree label . TreeForest . fmap Left

-- | Build a node whose forest consists only of the given child subtrees, in
-- container order.
--
-- >>> makeChildren 1 []
-- Tree 1 (TreeForest [])
--
-- >>> makeChildren 1 [makeChildren 2 []]
-- Tree 1 (TreeForest [Right (Tree 2 (TreeForest []))])
--
-- >>> makeChildren 1 [makeChildren 2 [], makeLeaves 3 []]
-- Tree 1 (TreeForest [Right (Tree 2 (TreeForest [])),Right (Tree 3 (TreeForest []))])
makeChildren ::
  Functor f =>
  a
  -> f (Tree f a b)
  -> Tree f a b
makeChildren label =
  Tree label . TreeForest . fmap Right

-- | Convert between a list-based 'TreeList'' and a "Data.Tree" rose tree.
--
-- Viewing turns every leaf @Left x@ into a childless @Tree.Node x []@ and
-- converts child subtrees recursively. Reviewing turns every childless
-- 'Tree.Node' below the root into a leaf, and every other node into a
-- child subtree.
--
-- __Warning:__ this is not a lawful 'Iso'. A childless child subtree
-- (e.g. @makeChild 6 []@) and a leaf (@Left 6@) both view as
-- @Tree.Node 6 []@. As a result, @review baseTree . view baseTree@ turns
-- such childless subtrees into leaves, as the last two examples show. Only
-- @view baseTree . review baseTree@ is the identity.
--
-- >>> view baseTree (Tree 1 (TreeForest []))
-- Node {rootLabel = 1, subForest = []}
--
-- >>> view baseTree (Tree 1 (TreeForest [Left 2, makeChild 3 [Left 5], Left 4, makeChild 6 []]))
-- Node {rootLabel = 1, subForest = [Node {rootLabel = 2, subForest = []},Node {rootLabel = 3, subForest = [Node {rootLabel = 5, subForest = []}]},Node {rootLabel = 4, subForest = []},Node {rootLabel = 6, subForest = []}]}
--
-- >>> review baseTree (Tree.Node 1 [])
-- Tree 1 (TreeForest [])
--
-- >>> review baseTree (Tree.Node 1 [Tree.Node 2 [],Tree.Node 3 [Tree.Node 5 []],Tree.Node 4 [],Tree.Node 6 []])
-- Tree 1 (TreeForest [Left 2,Right (Tree 3 (TreeForest [Left 5])),Left 4,Left 6])
baseTree ::
  Iso' (TreeList' a) (Tree.Tree a)
baseTree =
  iso toRose fromRose
  where
    -- Leaves become childless nodes; subtrees are converted recursively.
    toRose (Tree label forest) =
      Tree.Node label (either pure toRose <$> view _Wrapped forest)
    -- Each child node is classified as a leaf or a subtree.
    fromRose (Tree.Node label kids) =
      Tree label (TreeForest (classify <$> kids))
    -- A childless node becomes a leaf, which is where the round trip loses
    -- information. Any other node becomes a subtree.
    classify (Tree.Node label []) = Left label
    classify node = Right (fromRose node)
