-- | Miscellaneous utility functions

{-# LANGUAGE CPP #-}
module Data.Generics.Fixplate.Misc where

--------------------------------------------------------------------------------

import Prelude hiding ( mapM , mapM_ )

import Data.List ( sortBy , groupBy )
import Data.Ord

import Data.Traversable

import Control.Applicative ()
import Control.Monad ( ap , liftM )

--import Control.Monad.Trans.State

--------------------------------------------------------------------------------

-- | A value holding nothing ('Empty'), an @a@ ('One'), or a @b@ ('Two').
data Two a b = Empty | One a | Two b
  deriving (Show)

-- | A value holding nothing ('None'), just an @a@ ('First'), or an @a@
-- together with a @b@ ('Both').
data Both a b = None | First a | Both a b
  deriving (Show)

--------------------------------------------------------------------------------

-- | Equality on the image of a function: @equating f x y@ holds when @f@
-- maps @x@ and @y@ to equal results. The 'Eq' counterpart of 'comparing'.
equating :: Eq b => (a -> b) -> a -> a -> Bool
equating f = \x y -> f x == f y

-- | Sort a list by a key, then split it into runs of elements whose keys
-- are equal. Groups come out in ascending key order, and because 'sortBy'
-- is stable each group keeps the input's relative order.
groupSortOn :: Ord b => (a -> b) -> [a] -> [[a]]
groupSortOn f = groupBy (equating f) . sortBy (comparing f)

-- | Group a list by a key (sorted, as in 'mapGroupSortOn''), pairing each
-- key with the group's elements after applying @proj@ to every one of them.
mapGroupSortOn :: Ord b => (a -> b) -> (a -> c) -> [a] -> [(b,[c])]
mapGroupSortOn key proj xs = mapGroupSortOn' key (map proj) xs

-- | Sort a list by a key, group runs of equal keys, and summarise each
-- group: every result pairs a group's key with @g@ applied to the whole
-- group. Groups appear in ascending key order and keep the input's
-- relative order within a group ('sortBy' is stable).
mapGroupSortOn' :: Ord b => (a -> b) -> ([a] -> c) -> [a] -> [(b,c)]
mapGroupSortOn' f g xs =
  -- 'groupBy' never yields an empty group, so the pattern always matches;
  -- binding the first element here replaces the partial 'head' without
  -- changing any result.
  [ (f y, g grp) | grp@(y:_) <- groups ]
  where
    groups = groupBy (equating f) (sortBy (comparing f) xs)

--------------------------------------------------------------------------------

-- | Turn a 'Maybe'-returning function into a partial one. Apply @safe@ to
-- the argument and return the wrapped value, or call @'error' msg@ when
-- the result is 'Nothing'.
unsafe :: (a -> Maybe b) -> String -> a -> b
unsafe safe msg = maybe (error msg) id . safe

--------------------------------------------------------------------------------

-- | Precedence of function application in Haskell (10); presumably meant
-- for 'showsPrec'-style printers, e.g. @showParen (d > app_prec)@.
app_prec :: Int
app_prec = 10

--------------------------------------------------------------------------------

-- | Apply one function to the first component of a pair and another to
-- the second. The pair argument is matched strictly.
(<#>) :: (a -> b) -> (c -> d) -> (a,c) -> (b,d)
(<#>) f g = \(x, y) -> (f x, g y)

-- | Map a function over the first component of a pair; the second is
-- passed through unchanged. The pair is matched strictly.
first :: (a -> b) -> (a,c) -> (b,c)
first f p = case p of
  (l, r) -> (f l, r)

-- | Map a function over the second component of a pair; the first is
-- passed through unchanged. The pair is matched strictly.
second :: (b -> c) -> (a,b) -> (a,c)
second g p = case p of
  (l, r) -> (l, g r)

--------------------------------------------------------------------------------

-- | Keep applying @f@ while it succeeds. The result is the first value on
-- which @f@ returns 'Nothing' (the input itself if @f@ fails at once).
-- Does not terminate if @f@ never returns 'Nothing'.
tillNothing :: (a -> Maybe a) -> a -> a
tillNothing f x = maybe x (tillNothing f) (f x)

-- | Run a list of 'Maybe'-returning steps left to right, feeding each
-- step's output to the next. The result is 'Nothing' as soon as one step
-- fails; an empty list returns the input wrapped in 'Just'.
chain :: [a -> Maybe a] -> a -> Maybe a
chain = foldr (\step rest y -> step y >>= rest) return

-- | Like 'chain', but unwraps the final value, calling 'error' if any
-- step returned 'Nothing'.
chainJust :: [a -> Maybe a] -> a -> a
chainJust fs = maybe (error "chainJust: Nothing") id . chain fs

--------------------------------------------------------------------------------  

-- | @iterateN n f x@ applies @f@ to @x@ exactly @n@ times, i.e. it is the
-- element at index @n@ of @'iterate' f x@. A zero or negative @n@ returns
-- @x@ unchanged (previously a negative count recursed forever).
iterateN :: Int -> (a -> a) -> a -> a
iterateN n f = go n where
  -- Each intermediate is forced to WHNF before the next step, so a long
  -- run does not build a chain of @n@ nested 'f' thunks first.
  go k x
    | k <= 0    = x
    | otherwise = let y = f x in y `seq` go (k - 1) y

--------------------------------------------------------------------------------  

-- | Run a monadic action on every element of a structure, in the
-- structure's fold order, discarding the results.
--
-- Implemented as a right fold of '>>' over the 'Foldable' superclass
-- rather than via 'mapM'. That avoids building a result structure only
-- to throw it away.
mapM_ :: (Traversable t, Monad m) => (a -> m ()) -> t a -> m ()
mapM_ act = foldr (\x rest -> act x >> rest) (return ())

-- | Monadic accumulating map. Thread an accumulator through a traversal
-- while running effects in @m@. Each call @user acc x@ yields the next
-- accumulator and the mapped element. The overall result is the final
-- accumulator paired with the mapped structure.
mapAccumM :: (Traversable t, Monad m) => (a -> b -> m (a, c)) -> a -> t b -> m (a, t c)
mapAccumM user x0 t = runStateT (mapM step t) x0 >>= return . swap
  where
    -- Reorder @user@'s (accumulator, value) result into the
    -- (value, state) layout that 'StateT' uses.
    step x = StateT $ \acc ->
      user acc x >>= \(acc', y) -> return (y, acc')

-- | Exchange the two components of a pair (strict in the pair).
swap :: (a,b) -> (b,a)
swap = \(l, r) -> (r, l)

--------------------------------------------------------------------------------  

-- | State transformer. A computation that takes a state @s@ and returns,
-- inside the monad @m@, a result @a@ paired with the updated state.
newtype StateT s m a = StateT
  { runStateT :: s -> m (a, s)
  }

instance (Functor m) => Functor (StateT s m) where
  -- Map over the result component only. The lazy pattern means 'fmap'
  -- does not itself force the pair produced by the wrapped computation.
  fmap f (StateT g) = StateT (fmap step . g)
    where
      step ~(a, s') = (f a, s')

instance (Functor m, Monad m) => Applicative (StateT s m) where
  -- 'pure' is defined directly instead of as @pure = return@. That
  -- backwards form is flagged by GHC's -Wnoncanonical-monad-instances,
  -- and 'return' is meant to default to 'pure', not the other way round.
  pure a = StateT $ \s -> return (a, s)
  -- Application sequences through the 'Monad' instance.
  (<*>) = ap

-- Sequencing threads the state: @m >>= k@ runs @m@ on the incoming state,
-- then runs @k@'s computation on the state @m@ produced. The result pair
-- is bound with a lazy pattern.
instance (Monad m) => Monad (StateT s m) where
  return a = state $ \s -> (a, s)
  m >>= k  = StateT $ \s -> do
    ~(a, s') <- runStateT m s
    runStateT (k a) s'
#if !MIN_VERSION_base(4,13,0)
  -- Before base-4.13 'fail' is still a method of 'Monad'. Failure is
  -- delegated to the underlying monad and the state is discarded.
  fail str = StateT $ \_ -> fail str
#endif

#if MIN_VERSION_base(4,13,0)
-- From base-4.13 on 'fail' lives in 'MonadFail'. Failure is delegated to
-- the underlying monad and the state is discarded.
instance MonadFail m => MonadFail (StateT s m) where
  fail str = StateT $ \_ -> fail str
#endif

-- | Lift a pure state transition into 'StateT'. @f@ receives the current
-- state and returns a result together with the new state.
state :: (Monad m) => (s -> (a,s)) -> StateT s m a
state f = StateT $ \s -> return (f s)

-- | Return the current state as the result, leaving the state unchanged.
sget :: (Monad m) => StateT s m s
sget = StateT $ \cur -> return (cur, cur)

-- | Replace the state with the given value, ignoring the old one.
sput :: (Monad m) => s -> StateT s m ()
sput new = StateT $ \_ -> return ((), new)

-- | Update the state by applying the given function to it.
smodify :: (Monad m) => (s -> s) -> StateT s m ()
smodify f = StateT $ \cur -> return ((), f cur)

--------------------------------------------------------------------------------