-- | Miscellaneous utility functions
module Data.Generics.Fixplate.Misc where

--------------------------------------------------------------------------------

import Prelude hiding (mapM,mapM_)

import Data.List ( sortBy , groupBy )
import Data.Ord

import Data.Traversable

import qualified Control.Monad.Fail as Fail

--import Control.Monad (liftM)
--import Control.Monad.Trans.State

--------------------------------------------------------------------------------
        
-- | At most one value, taken from one of two types.
--
-- Has the same shape as @Maybe (Either a b)@: 'Empty' carries nothing,
-- 'One' carries an @a@, and 'Two' carries a @b@.
data Two a b
  = Empty   -- ^ no value
  | One a   -- ^ a value of the first type
  | Two b   -- ^ a value of the second type
  deriving Show
  
-- | Nothing, a first value alone, or a first value together with a second.
--
-- Has the same shape as @Maybe (a, Maybe b)@: 'None' carries nothing,
-- 'First' carries only an @a@, and 'Both' carries an @a@ and a @b@.
data Both a b
  = None       -- ^ neither value
  | First a    -- ^ only the first value
  | Both a b   -- ^ the first and the second value
  deriving Show

--------------------------------------------------------------------------------

-- | @equating f x y@ holds when @x@ and @y@ have equal images under @f@.
-- It is the equality counterpart of 'comparing'.
equating :: Eq b => (a -> b) -> a -> a -> Bool
equating key = \lhs rhs -> (==) (key lhs) (key rhs)

-- | Sort a list on a key, then split it into maximal runs that share a key.
--
-- Groups come out in ascending key order. Because 'sortBy' is stable,
-- elements within a group keep their original relative order.
groupSortOn :: Ord b => (a -> b) -> [a] -> [[a]]
groupSortOn key = groupBy sameKey . sortBy (comparing key)
  where
    sameKey p q = key p == key q

-- | Sort and group a list on a key, as 'groupSortOn' does. Pair each group's
-- key with the group's elements mapped through @g@.
mapGroupSortOn :: Ord b => (a -> b) -> (a -> c) -> [a] -> [(b,[c])]
mapGroupSortOn f = mapGroupSortOn' f . map

-- | Sort a list on a key, group runs of equal keys, and summarise each group.
--
-- Each result pairs a group's shared key with @g@ applied to the whole group.
-- Groups appear in ascending key order. Elements within a group keep their
-- original relative order, because 'sortBy' is stable.
mapGroupSortOn' :: Ord b => (a -> b) -> ([a] -> c) -> [a] -> [(b,c)]
mapGroupSortOn' f g xs =
  -- The key is read off the group's first element by pattern matching rather
  -- than the partial 'head'. 'groupBy' never yields an empty group, so the
  -- pattern always matches and no group is skipped.
  [ (f y, g grp)
  | grp@(y:_) <- groupBy sameKey (sortBy (comparing f) xs)
  ]
  where
    sameKey u v = f u == f v

--------------------------------------------------------------------------------

-- | Turn a 'Maybe'-returning function into a partial one.
--
-- @unsafe safe msg loc@ returns the payload of @safe loc@. If that is
-- 'Nothing', it calls 'error' with @msg@.
unsafe :: (a -> Maybe b) -> String -> a -> b
unsafe safe msg loc = maybe (error msg) id (safe loc)
  
--------------------------------------------------------------------------------

-- | The precedence level 10. Presumably used as the precedence of function
-- application in hand-written @showsPrec@ instances elsewhere (verify
-- against callers).
app_prec :: Int
app_prec =
  10

--------------------------------------------------------------------------------

-- | Apply two functions to the respective components of a pair.
-- The pair itself is matched strictly.
(<#>) :: (a -> b) -> (c -> d) -> (a,c) -> (b,d)
(<#>) onFst onSnd = \(l, r) -> (onFst l, onSnd r)

-- | Map a function over the first component of a pair (matched strictly).
first :: (a -> b) -> (a,c) -> (b,c)
first f p = case p of
  (l, r) -> (f l, r)

-- | Map a function over the second component of a pair (matched strictly).
second :: (b -> c) -> (a,b) -> (a,c)
second g p = case p of
  (l, r) -> (l, g r)

--------------------------------------------------------------------------------

-- | Keep applying a step function while it returns 'Just'.
--
-- The result is the first value on which the step returns 'Nothing'. If the
-- step never returns 'Nothing', this does not terminate.
tillNothing :: (a -> Maybe a) -> a -> a
tillNothing step = loop
  where
    loop cur = maybe cur loop (step cur)
  
-- | Run a list of partial steps left to right, feeding each result into the
-- next step.
--
-- The first 'Nothing' short-circuits, and later steps are not called.
-- An empty list of steps returns the input wrapped in 'Just'.
chain :: [a -> Maybe a] -> a -> Maybe a
chain = foldr (\step rest x -> step x >>= rest) return
  
-- | Like 'chain', but partial. It calls 'error' if any step returns
-- 'Nothing'.
chainJust :: [a -> Maybe a] -> a -> a
chainJust fs = maybe (error "chainJust: Nothing") id . chain fs
  
--------------------------------------------------------------------------------  

-- | @iterateN n f x@ applies @f@ to @x@ exactly @n@ times. For non-negative
-- @n@ this equals @iterate f x !! n@.
--
-- A negative count is treated like zero and returns @x@ unchanged. An exact
-- @== 0@ base case would never be reached from below, so the recursion would
-- effectively never terminate.
iterateN :: Int -> (a -> a) -> a -> a
iterateN n f = go n
  where
    go k x
      | k <= 0    = x
      | otherwise = go (k - 1) (f x)

--------------------------------------------------------------------------------  

-- | Run a monadic action on every element of a 'Traversable' container,
-- keeping only the effects.
--
-- This goes through 'mapM', so the unit-filled result container is built
-- and then discarded.
mapM_ :: (Traversable t, Monad m) => (a -> m ()) -> t a -> m ()
mapM_ act t = mapM act t >>= \_ -> return ()

-- | Monadic analogue of 'mapAccumL'. It threads an accumulator through a
-- 'Traversable' container from left to right.
--
-- Returns the final accumulator together with the container of results.
mapAccumM :: (Traversable t, Monad m) => (a -> b -> m (a, c)) -> a -> t b -> m (a, t c)
mapAccumM act x0 t = runStateT (mapM step t) x0
  where
    -- Each element becomes a state action whose state is the accumulator.
    step b = StateT (\acc -> act acc b)

--------------------------------------------------------------------------------  

-- | A minimal lazy state monad transformer.
--
-- A computation takes an initial state and yields, in the underlying monad
-- @m@, the final state paired with a result. Note the tuple order is
-- @(state, result)@, the reverse of @transformers@' @StateT@.
newtype StateT s m a = StateT { runStateT :: s -> m (s,a) }

-- Pairs from the inner monad are matched lazily (@~@) throughout, so these
-- instances agree with the lazy pattern used by '>>='.
instance Functor m => Functor (StateT s m) where
  fmap f m = StateT $ \s -> fmap (\ ~(s', a) -> (s', f a)) (runStateT m s)

instance Monad m => Applicative (StateT s m) where
  pure a = StateT $ \s -> return (s, a)
  mf <*> mx = StateT $ \s -> do
    ~(s' , f) <- runStateT mf s
    ~(s'', x) <- runStateT mx s'
    return (s'', f x)

-- 'return' is left to its default ('pure'). Since GHC 7.10 'Monad' requires
-- the 'Applicative' superclass instance above.
instance Monad m => Monad (StateT s m) where
  m >>= k  = StateT $ \s -> do
    ~(s', a) <- runStateT m s
    runStateT (k a) s'

-- | Failure ignores the state and is delegated to the underlying monad.
-- Since GHC 8.8 'fail' lives in 'MonadFail' rather than 'Monad'.
instance Fail.MonadFail m => Fail.MonadFail (StateT s m) where
  fail str = StateT $ \_ -> Fail.fail str

-- | Lift a pure state transition (returning @(newState, result)@) into
-- 'StateT'.
state :: Monad m => (s -> (s,a)) -> StateT s m a
state f = StateT $ \s -> return (f s)

-- | Return the current state without changing it.
sget :: (Monad m) => StateT s m s
sget = StateT (\cur -> return (cur, cur))

-- | Replace the state with the given value.
sput :: (Monad m) => s -> StateT s m ()
sput new = state (const (new, ()))

-- | Apply a function to the state. The new state is not forced.
smodify :: (Monad m) => (s -> s) -> StateT s m ()
smodify f = state (\old -> let new = f old in (new, ()))

--------------------------------------------------------------------------------