{-# LANGUAGE FlexibleInstances, FunctionalDependencies, MultiParamTypeClasses, RankNTypes, UndecidableInstances #-}

-- | The 'List' class and actions for lists

module Data.List.Class (
  -- | The List typeclass
  List (..),
  -- | List operations for MonadPlus
  cons, fromList, filter, repeat,
  -- | Standard list operations
  takeWhile, genericTake, scanl,
  sequence, sequence_, transpose,
  zip, zipWith,
  -- | Non standard List operations
  foldlL, toList, execute, lengthL, lastL,
  -- | Convert between List types
  convList, transformListMonad, liftListMonad
  ) where

import Control.Monad (MonadPlus(..), ap, join, liftM)
import Control.Monad.Identity (Identity(..))
import Control.Monad.ListT (ListT(..), ListItem(..), foldrListT)
import Control.Monad.Trans (MonadTrans(..))
import Data.Function (fix)
import Prelude hiding (
  filter, repeat, scanl, sequence, sequence_, takeWhile, zip, zipWith)

-- | Types that behave like lists whose items may be produced by actions
-- in an underlying monad @m@.
--
-- The functional dependency says that the list type determines its monad.
--
-- The defaults of 'foldrL' and 'toListT' are defined in terms of each
-- other ('toListT' goes through 'convList', which folds with 'foldrL'),
-- so an instance has to define 'joinL' and at least one of those two.
class (MonadPlus l, Monad m) => List l m | l -> m where
  -- | Turn an action that returns a list into that list.
  --
  -- > > joinL $ Identity "hello"
  -- > "hello"
  joinL :: m (l b) -> l b
  -- | Right fold over a 'List'.
  -- Both the folded remainder and the result are monadic actions.
  foldrL :: (a -> m b -> m b) -> m b -> l a -> m b
  foldrL step end list = foldrL step end (toListT list)
  -- | View the list as a 'ListT' over the same monad.
  --
  -- Expressible through 'foldrL'; kept as a class method so that
  -- instances can supply a more efficient definition.
  toListT :: l a -> ListT m a
  toListT list = convList list
  -- | Build the list from a 'ListT' over the same monad.
  --
  -- Expressible through 'foldrL'; kept as a class method so that
  -- instances can supply a more efficient definition.
  fromListT :: ListT m a -> l a
  fromListT list = convList list

-- | Ordinary lists are 'List's over the 'Identity' monad.
instance List [] Identity where
  -- Unwrapping the 'Identity' newtype is all that is needed.
  joinL (Identity list) = list
  foldrL step end = foldr step end
  -- 'fromList' is used here at the 'ListT' 'Identity' type.
  toListT list = fromList list

-- | 'ListT' is a 'List' over the monad it transforms.
instance Monad m => List (ListT m) m where
  -- Run the outer action, then continue with the list it returned.
  joinL action = ListT (action >>= runListT)
  foldrL = foldrListT
  -- Conversions to and from 'ListT' are trivial for 'ListT' itself.
  toListT list = list
  fromListT list = list

-- | Put a single item in front of a 'MonadPlus' value.
--
-- The item is injected with 'return' and combined with the rest
-- using 'mplus'.
cons :: MonadPlus m => a -> m a -> m a
cons x rest = return x `mplus` rest

-- | Build a 'MonadPlus' value out of the items of an ordinary list.
--
-- Each item is injected with 'return' and the results are combined,
-- right to left, with 'mplus'; the empty list becomes 'mzero'.
--
-- > > fromList [] :: Maybe Int
-- > Nothing
-- > > fromList [5] :: Maybe Int
-- > Just 5
fromList :: MonadPlus m => [a] -> m a
fromList items =
  foldr prepend mzero items
  where
    prepend x rest = return x `mplus` rest

-- | Convert one 'List' type to another one that shares its underlying monad.
--
-- The source is right-folded; each item is prepended (with 'cons') to the
-- target list obtained from the rest of the fold via 'joinL'.
convList :: (List l m, List k m) => l a -> k a
convList list =
  joinL (foldrL prepend (return mzero) list)
  where
    prepend x rest = return (cons x (joinL rest))

-- | Keep only the items that satisfy a predicate, for any 'MonadPlus'.
--
-- Items that fail the predicate are replaced by 'mzero'.
--
-- > > filter (> 5) (Just 3)
-- > Nothing
filter :: MonadPlus m => (a -> Bool) -> m a -> m a
filter cond items =
  items >>= keep
  where
    keep x = if cond x then return x else mzero

-- Shared engine for the left-to-right traversals in this module
-- (used by foldlL, scanl and genericTake).
--
-- The list is right-folded into an action that yields a function from the
-- current accumulator to the final result. Arguments:
--
--   joinVals cur next - combine the accumulator reached so far with the
--                       action computing the result for the rest of the list
--   atEnd             - result for the accumulator once the list is exhausted
--   step              - update the accumulator with the next item
--   startVal          - initial accumulator
foldlL' :: List l m =>
  (a -> m c -> c) -> (a -> c) -> (a -> b -> a) -> a -> l b -> c
foldlL' joinVals atEnd step startVal list =
  continue startVal (foldrL consStep (return atEnd) list)
  where
    -- For an item: advance the accumulator, then carry on with the rest.
    consStep x rest = return (\acc -> continue (step acc x) rest)
    -- Feed the current accumulator to the pending function-producing
    -- action and let joinVals decide how to combine the two.
    continue cur next = joinVals cur (next `ap` return cur)

-- | Left fold over a 'List', producing an action in the underlying monad
-- that yields the final accumulator.
foldlL :: List l m => (a -> b -> a) -> a -> l b -> m a
foldlL step startVal =
  foldlL' flatten id stepM (return startVal)
  where
    -- The accumulator is itself an action, so the pending result only
    -- needs to be flattened; the accumulator argument is not used here.
    flatten _ = join
    stepM accM x = liftM (\acc -> step acc x) accM

-- | Left scan for 'List's: the start value followed by every
-- intermediate accumulator of the corresponding left fold.
scanl :: List l m => (a -> b -> a) -> a -> l b -> l a
scanl step startVal =
  foldlL' emit (const mzero) step startVal
  where
    -- Output the current accumulator, then the scan of the remainder.
    emit cur rest = cons cur (joinL rest)

-- | Take at most @count@ items from the front of a 'List'.
--
-- A non-positive @count@ gives the empty list. Once the last wanted item
-- has been emitted the remainder of the list is not used at all.
genericTake :: (Integral i, List l m) => i -> l a -> l a
genericTake count
  | count <= 0 = const mzero
  | otherwise = foldlL' emit (const mzero) advance Nothing
  where
    -- The accumulator holds how many items are still wanted (counting the
    -- stored one) together with the most recent item; Nothing means that
    -- no item has been seen yet.
    advance Nothing x = Just (count, x)
    advance (Just (remaining, _)) x = Just (remaining - 1, x)
    emit Nothing rest = joinL rest
    emit (Just (remaining, x)) rest
      -- Last wanted item: emit it and ignore the rest.
      | remaining == 1 = return x
      | otherwise = cons x (joinL rest)

-- | Run through a whole 'List' for the sake of its monadic actions,
-- discarding the items.
execute :: List l m => l a -> m ()
execute list = foldlL (\unit _ -> unit) () list

-- | Turn a 'List' of actions into an action producing the 'List' of their
-- results.
--
-- As written, each step binds only the action of the current item; the
-- actions of the remaining items are wrapped into the result with 'joinL'.
sequence :: List l m => l (m a) -> m (l a)
sequence =
  foldrL runAndCons (return mzero)
  where
    runAndCons action rest =
      action >>= \x -> return (cons x (joinL rest))

-- | Run all the actions stored in a 'List', discarding their results.
--
-- Defined as 'execute' applied to the list produced by 'sequence'.
sequence_ :: List l m => l (m a) -> m ()
sequence_ actions = execute (joinL (sequence actions))

-- | The longest prefix of a 'List' whose items all satisfy the predicate.
--
-- The remainder of the list is not used once an item fails the predicate.
takeWhile :: List l m => (a -> Bool) -> l a -> l a
takeWhile cond list =
  joinL (foldrL keep (return mzero) list)
  where
    keep x rest
      | cond x = return (cons x (joinL rest))
      | otherwise = return mzero

-- | An action that collects all items of a 'List' into an ordinary list.
--
-- > > runIdentity $ toList "hello!"
-- > "hello!"
toList :: List l m => l a -> m [a]
toList =
  foldrL prepend (return [])
  where
    prepend x rest = liftM (x :) rest

-- | Consume a list (executing its actions) and return the number of items.
--
-- > > runIdentity $ lengthL [1,2,3]
-- > 3
lengthL :: (Integral i, List l m) => l a -> m i
lengthL = foldlL (\count _ -> count + 1) 0

-- | Change the underlying monad of a list, given a transformation from the
-- old monad to the new one.
--
-- The transformation is applied to the action of every fold step before
-- the step is joined into the target list.
--
-- > > import Data.List.Tree (bfs)
-- > > bfs (transformListMonad (\(Identity x) -> [x, x]) "hey" :: ListT [] Char)
-- > "hheeeeyyyyyyyy"
transformListMonad :: (List l m, List k s) =>
  (forall x. m x -> s x) -> l a -> k a
transformListMonad trans list =
  lower (foldrL prepend (return mzero) list)
  where
    -- Move an action into the target monad and flatten it into the list.
    lower action = joinL (trans action)
    prepend x rest = return (cons x (lower rest))

-- | Lift the underlying monad of a list through a monad transformer,
-- producing a 'ListT' over the transformed monad.
--
-- Writing plain @transformListMonad lift@ at the use site instead doesn't
-- give the compiler the same knowledge about the types.
liftListMonad ::
  (MonadTrans t, Monad (t m), List l m) =>
  l a -> ListT (t m) a
liftListMonad list = transformListMonad lift list

-- | Pair up the items of two 'List's, stopping as soon as either one ends.
--
-- Both lists are converted with 'toListT'. At every step the first list is
-- advanced before the second, and the second is only advanced when the
-- first produced an item.
zip :: List l m => l a -> l b -> l (a, b)
zip as bs =
  go (toListT as) (toListT bs)
  where
    go lefts rights =
      joinL $ do
        leftItem <- runListT lefts
        case leftItem of
          Nil -> return mzero
          Cons x xs -> do
            rightItem <- runListT rights
            return $ case rightItem of
              Nil -> mzero
              Cons y ys -> cons (x, y) (go xs ys)

-- | Combine the items of two 'List's pairwise with the given function.
--
-- Implementation note: 'zipWith' is based on 'zip' and not vice versa,
-- because the other way around hlint complains "use zip".
zipWith :: List l m => (a -> b -> c) -> l a -> l b -> l c
zipWith func xs ys = liftM (uncurry func) (zip xs ys)

-- | Consume all items and return the last one.
--
-- The list must not be empty. For an empty list the fold never replaces
-- its seed, so the yielded value is a bottom; forcing it raises an error
-- whose message names this function.
--
-- > > runIdentity $ lastL "hello"
-- > 'o'
lastL :: List l m => l a -> m a
lastL =
  -- Every item replaces the accumulator, so the seed is only ever the
  -- result for an empty list; give it a message instead of bare undefined.
  foldlL (const id) (error "Data.List.Class.lastL: empty list")

-- | An endless supply of the same item, for any 'MonadPlus'.
--
-- The value is defined in terms of itself, so it is presumably only
-- usable when 'mplus' is lazy enough in its second argument.
repeat :: MonadPlus m => a -> m a
repeat x =
  items
  where
    items = cons x items

-- | Transpose a 'List' of 'List's.
--
-- The outer list is first collected completely with 'toList'. Each output
-- row holds the head items of all inner lists that are not yet exhausted;
-- exhausted inner lists are dropped, and the result ends once none remain.
transpose :: List l m => l (l a) -> l (l a)
transpose matrix =
  joinL $ do
    rows <- toList matrix
    step (map toListT rows)
  where
    step rows = do
      items <- mapM runListT rows
      -- Keep only the rows that still have an item, split into head and tail.
      return $ case [(x, rest) | Cons x rest <- items] of
        [] -> mzero
        conses ->
          cons (fromList (map fst conses)) (joinL (step (map snd conses)))