{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# OPTIONS_GHC -fno-warn-tabs #-}

module Data.TreeMap.Strict.Zipper where

import           Control.Monad (Monad(..), (>=>))
import           Control.Applicative (Applicative(..), Alternative(..))
import           Data.Bool (Bool)
import           Data.Data (Data)
import           Data.Eq (Eq)
import           Data.Function (($), (.))
import           Data.Functor ((<$>))
import           Data.Int (Int)
import qualified Data.List as List
import           Data.List.NonEmpty (NonEmpty(..))
import qualified Data.Map.Strict as Map
import           Data.Maybe (Maybe(..), maybe, maybeToList)
import           Data.Ord (Ord(..))
import           Data.Tuple (fst)
import           Data.Typeable (Typeable)
import           Text.Show (Show(..))

import           Data.TreeMap.Strict (TreeMap(..))
import qualified Data.TreeMap.Strict as TreeMap

-- * Type 'Zipper'

-- | A zipper over a 'TreeMap': the 'TreeMap' under the focus, plus the
--   steps taken from the root to reach it.
data Zipper k a
 =   Zipper
 {   zipper_path :: [Zipper_Step k a]
     -- ^ Steps from the focus back up to the root, the most recent
     --   (innermost) step first; empty when focused on the root.
 ,   zipper_curr :: TreeMap k a
     -- ^ The 'TreeMap' under the focus: the descendants of the focused
     --   node, or the whole tree when 'zipper_path' is empty.
 } deriving (Data, Eq, Show, Typeable)

-- | Start a 'Zipper' on the root of the given 'TreeMap' (no path yet).
zipper :: TreeMap k a -> Zipper k a
zipper m = Zipper { zipper_path = [], zipper_curr = m }

-- | Climb with 'zipper_parent' until the root is reached and return the
--   whole (rebuilt) 'TreeMap' held there.
zipper_root :: Ord k => Zipper k a -> TreeMap k a
zipper_root z = maybe (zipper_curr z) zipper_root (zipper_parent z)

-- | The keys leading from the root down to the focus, root-most key first.
path_of_zipper :: Zipper k x -> [k]
path_of_zipper =
	List.reverse . List.map (fst . zipper_step_self) . zipper_path

-- * Type 'Zipper_Step'

-- | One level of a 'Zipper' path: the focused entry at that level together
--   with its siblings, split around the focused key.
data Zipper_Step k a
 =   Zipper_Step
 {   zipper_step_prec :: TreeMap k a
     -- ^ Siblings whose keys are smaller than the focused key.
 ,   zipper_step_self :: (k, TreeMap.Node k a)
     -- ^ The focused key and the node stored under it.
 ,   zipper_step_foll :: TreeMap k a
     -- ^ Siblings whose keys are greater than the focused key.
 } deriving (Data, Eq, Show, Typeable)

-- * Axis

-- | Repeatedly apply the step function @f@ starting from @z@ until it
--   yields 'Nothing', returning every value visited,
--   including the starting one. The result is produced lazily.
zipper_collect :: (z -> Maybe z) -> z -> [z]
zipper_collect f = List.unfoldr visit . Just
	where
		visit mz = (\x -> (x, f x)) <$> mz

-- | Repeatedly apply the step function @f@ starting from @z@ until it
--   yields 'Nothing', returning every value visited,
--   excluding the starting one. The result is produced lazily.
zipper_collect_without_self :: (z -> Maybe z) -> z -> [z]
zipper_collect_without_self f = walk . f
	where
		walk Nothing   = []
		walk (Just z') = z' : walk (f z')

-- ** Axis self

-- | The focused node as recorded in the innermost path step, or
--   'TreeMap.node_empty' when the zipper sits on the root.
--   NOTE: the node is returned as stored in the path; it is not rebuilt
--   from 'zipper_curr'.
zipper_self :: Zipper k a -> TreeMap.Node k a
zipper_self Zipper{zipper_path = []} = TreeMap.node_empty
zipper_self Zipper{zipper_path = step : _} =
	case step of
	 Zipper_Step{zipper_step_self = (_, nod)} -> nod

-- ** Axis child

-- | All direct children of the focus, in ascending key order:
--   the first child followed by its following siblings.
zipper_child :: Ord k => Zipper k a -> [Zipper k a]
zipper_child z =
	maybe [] (zipper_collect zipper_foll) (zipper_child_first z)

-- | Focus on the direct child stored under key @k@, or 'empty' when the
--   focus has no such child. Children with smaller keys become the new
--   step's 'zipper_step_prec', those with greater keys its 'zipper_step_foll'.
zipper_child_lookup
 :: (Ord k, Alternative f)
 => k -> Zipper k a -> f (Zipper k a)
zipper_child_lookup k (Zipper path (TreeMap m)) =
	maybe empty descend found
	where
		(before, found, after) = Map.splitLookup k m
		descend s =
			pure Zipper
			 { zipper_path = Zipper_Step (TreeMap before) (k, s) (TreeMap after) : path
			 , zipper_curr = TreeMap.node_descendants s
			 }

-- | Focus on the child with the smallest key, or 'empty' when the focus
--   has no children. Every other child goes to 'zipper_step_foll'.
zipper_child_first :: Alternative f => Zipper k a -> f (Zipper k a)
zipper_child_first (Zipper path (TreeMap m)) =
	maybe empty descend (Map.minViewWithKey m)
	where
		descend ((k, s), rest) =
			pure Zipper
			 { zipper_path = Zipper_Step TreeMap.empty (k, s) (TreeMap rest) : path
			 , zipper_curr = TreeMap.node_descendants s
			 }

-- | Focus on the child with the greatest key, or 'empty' when the focus
--   has no children. Every other child goes to 'zipper_step_prec'.
zipper_child_last :: Alternative f => Zipper k a -> f (Zipper k a)
zipper_child_last (Zipper path (TreeMap m)) =
	maybe empty descend (Map.maxViewWithKey m)
	where
		descend ((k, s), rest) =
			pure Zipper
			 { zipper_path = Zipper_Step (TreeMap rest) (k, s) TreeMap.empty : path
			 , zipper_curr = TreeMap.node_descendants s
			 }

-- ** Axis ancestor

-- | Proper ancestors of the focus, from the parent up to the root.
zipper_ancestor :: Ord k => Zipper k a -> [Zipper k a]
zipper_ancestor z = maybe [] zipper_ancestor_or_self (zipper_parent z)

-- | The focus itself, followed by its ancestors up to the root.
zipper_ancestor_or_self :: Ord k => Zipper k a -> [Zipper k a]
zipper_ancestor_or_self z = z : zipper_ancestor z

-- ** Axis descendant

-- | The focus followed by all its descendants in pre-order: each node
--   comes before its children, and siblings come in ascending key order.
--   The remaining output (@rest@) is threaded through the recursion
--   instead of concatenating sub-lists.
zipper_descendant_or_self :: Ord k => Zipper k a -> [Zipper k a]
zipper_descendant_or_self z0 = visit z0 []
	where
		-- Emit @z@, then its whole subtree, then @rest@.
		visit z rest =
			z : maybe rest (siblings rest) (zipper_child_first z)
		-- Emit the subtrees of @c@ and of every following sibling of @c@,
		-- then @rest@.
		siblings rest c =
			visit c (maybe rest (siblings rest) (zipper_foll c))

-- | The focus first, then the subtrees of its children taken in
--   descending key order, each subtree laid out the same way recursively.
zipper_descendant_or_self_reverse :: Ord k => Zipper k a -> [Zipper k a]
zipper_descendant_or_self_reverse z =
	z : (List.reverse (zipper_child z) >>= zipper_descendant_or_self_reverse)

-- | All descendants of the focus in pre-order, excluding the focus itself.
zipper_descendant :: Ord k => Zipper k a -> [Zipper k a]
zipper_descendant z = List.drop 1 (zipper_descendant_or_self z)

-- | Walk down from the focus along the given non-empty key path, one
--   'zipper_child_lookup' per key, chained with '>=>'. With 'Maybe' or
--   lists, a missing key along the way yields 'Nothing' / @[]@.
zipper_descendant_lookup
 :: (Ord k, Alternative f, Monad f)
 => TreeMap.Path k -> Zipper k a -> f (Zipper k a)
zipper_descendant_lookup (k0 :| ks0) = go k0 ks0
	where
		go k []        = zipper_child_lookup k
		go k (k' : ks) = zipper_child_lookup k >=> go k' ks

-- ** Axis preceding

-- | Move the focus to the nearest preceding sibling (the sibling with the
--   greatest key smaller than the current one), or 'empty' when the focus
--   is the root or has no preceding sibling.
--
--   The node being left is rebuilt from 'zipper_curr' before it is put
--   back among the following siblings, exactly as 'zipper_parent' does,
--   so edits made to 'zipper_curr' are not lost by a sideways move.
zipper_prec :: (Ord k, Alternative f) => Zipper k a -> f (Zipper k a)
zipper_prec (Zipper path curr) =
	case path of
	 [] -> empty
	 Zipper_Step (TreeMap ps) (k, s) (TreeMap fs):steps ->
		case Map.maxViewWithKey ps of
		 Nothing -> empty
		 Just ((k', s'), ps') ->
			-- Re-attach the current descendants instead of the (possibly
			-- stale) node recorded in the path step.
			let nod = TreeMap.node (TreeMap.node_value s) curr in
			pure Zipper
			 { zipper_path = Zipper_Step (TreeMap ps')
			                             (k', s')
			                             (TreeMap $ Map.insert k nod fs)
			                 : steps
			 , zipper_curr = TreeMap.node_descendants s'
			 }

-- | For the focus and each of its ancestors (innermost first), every
--   preceding sibling (nearest first), each expanded with
--   'zipper_descendant_or_self_reverse'.
zipper_preceding :: Ord k => Zipper k a -> [Zipper k a]
zipper_preceding z =
	[ d
	| a <- zipper_ancestor_or_self z
	, s <- zipper_preceding_sibling a
	, d <- zipper_descendant_or_self_reverse s
	]

-- | Preceding siblings of the focus, nearest (greatest smaller key) first.
zipper_preceding_sibling :: Ord k => Zipper k a -> [Zipper k a]
zipper_preceding_sibling z =
	maybe [] (zipper_collect zipper_prec) (zipper_prec z)

-- ** Axis following

-- | Move the focus to the nearest following sibling (the sibling with the
--   smallest key greater than the current one), or 'empty' when the focus
--   is the root or has no following sibling.
--
--   The node being left is rebuilt from 'zipper_curr' before it is put
--   back among the preceding siblings, exactly as 'zipper_parent' does,
--   so edits made to 'zipper_curr' are not lost by a sideways move.
zipper_foll :: (Ord k, Alternative f) => Zipper k a -> f (Zipper k a)
zipper_foll (Zipper path curr) =
	case path of
	 [] -> empty
	 Zipper_Step (TreeMap ps) (k, s) (TreeMap fs):steps ->
		case Map.minViewWithKey fs of
		 Nothing -> empty
		 Just ((k', s'), fs') ->
			-- Re-attach the current descendants instead of the (possibly
			-- stale) node recorded in the path step.
			let nod = TreeMap.node (TreeMap.node_value s) curr in
			pure Zipper
			 { zipper_path = Zipper_Step (TreeMap $ Map.insert k nod ps)
			                             (k', s')
			                             (TreeMap fs')
			                 : steps
			 , zipper_curr = TreeMap.node_descendants s'
			 }

-- | For the focus and each of its ancestors (innermost first), every
--   following sibling (ascending key order), each expanded with its
--   descendants in pre-order.
zipper_following :: Ord k => Zipper k a -> [Zipper k a]
zipper_following z =
	[ d
	| a <- zipper_ancestor_or_self z
	, s <- zipper_following_sibling a
	, d <- zipper_descendant_or_self s
	]

-- | Following siblings of the focus, nearest (smallest greater key) first.
zipper_following_sibling :: Ord k => Zipper k a -> [Zipper k a]
zipper_following_sibling z =
	maybe [] (zipper_collect zipper_foll) (zipper_foll z)

-- ** Axis parent

-- | Move the focus up one level, or 'empty' when already at the root.
--   The node that was focused is rebuilt from its recorded value and the
--   current 'zipper_curr', then merged back between its siblings.
zipper_parent :: (Ord k, Alternative f) => Zipper k a -> f (Zipper k a)
zipper_parent (Zipper path curr) =
	case path of
	 Zipper_Step (TreeMap ps) (k, s) (TreeMap fs) : steps ->
		let rebuilt  = TreeMap.node (TreeMap.node_value s) curr
		    siblings = Map.union ps (Map.insert k rebuilt fs)
		in pure (Zipper steps (TreeMap siblings))
	 [] -> empty

-- ** Filter

-- | Restrict an axis to the 'Zipper's satisfying the predicate @p@,
--   keeping the axis order.
zipper_filter
 :: (Zipper k a -> [Zipper k a])
 -> (Zipper k a -> Bool)
 -> (Zipper k a -> [Zipper k a])
zipper_filter axis p = List.filter p . axis
infixl 5 `zipper_filter`

-- | Select the @n@-th result (counting from 0) of @axis@, or 'empty' when
--   the axis yields fewer than @n + 1@ results.
--   A negative @n@ yields 'empty' instead of being treated like 0 (which
--   is what 'List.drop' alone would do, silently returning the first result).
zipper_at :: Alternative f
 => (Zipper k a -> [Zipper k a]) -> Int
 -> (Zipper k a -> f (Zipper k a))
zipper_at _axis n _z | n < 0 = empty
zipper_at axis n z =
	case List.drop n (axis z) of
	 []  -> empty
	 a:_ -> pure a
infixl 5 `zipper_at`

-- | Whether @axis@ yields no 'Zipper' at all from the given one.
zipper_null
 :: (Zipper k a -> [Zipper k a])
 -> Zipper k a -> Bool
zipper_null axis z = List.null (axis z)