{-|
  Functions on lists and sequences.
  Some of the functions follow the style of Data.Random.Extras 
  (from the random-extras package), but are written for use with
  PRNGs from the "mwc-random" package rather than from the "random-fu" package.
-}

module Language.Hakaru.Util.Extras where

import qualified Data.Sequence as S
import qualified System.Random.MWC as MWC
import Data.Maybe
import qualified Data.Foldable as F

{-|
  Remove the element at index @i@ from a sequence.
  Yields the sequence without that element, paired with the element
  itself, or 'Nothing' when no element is left at or after position @i@
  (e.g. the sequence is empty or @i@ is at least its length).
-}
extract :: S.Seq a -> Int -> Maybe (S.Seq a, a)
extract s i =
    case S.viewl rest of
      S.EmptyL     -> Nothing
      x S.:< after -> Just (before S.>< after, x)
  where
    -- 'rest' starts with the element to be removed (if there is one).
    (before, rest) = S.splitAt i s

{-|
  Draw an index from the inclusive range @0 .. length s - 1@ with the
  supplied generator and 'extract' the element at that index.
-}
randomExtract :: S.Seq a -> MWC.GenIO -> IO (Maybe (S.Seq a, a))
randomExtract s g = fmap (extract s) (MWC.uniformR (0, lastIx) g)
  where lastIx = S.length s - 1

{-|
  Given a sequence, return a *sorted* sequence of
  @n@ randomly selected elements from *distinct positions* in the sequence.

  NOTE(review): every call builds its own generator with 'MWC.create',
  which the mwc-random documentation describes as using a fixed seed, so
  repeated calls with the same arguments presumably make the same
  selection -- confirm that this is intended.
-}
randomElems :: Ord a => S.Seq a -> Int -> IO (S.Seq a)
randomElems s n = MWC.create >>= \gen -> randomElemsTR S.empty s gen n

{-|
  Worker for 'randomElems'.

  @ixs@ accumulates the elements picked so far, @s@ holds the elements
  still available, @g@ is the generator and @n@ is the number of elements
  that remain to be picked. The result is the sorted sequence of
  everything that was picked.

  * @n <= 0@: nothing is left to pick; the accumulator is sorted and
    returned without touching the generator.

  * @n >= length s@: every remaining element is taken and merged with
    the accumulator before sorting.

  * otherwise: one element is moved at random from @s@ to the
    accumulator and the worker recurses with @n - 1@.
-}
randomElemsTR :: Ord a => S.Seq a -> S.Seq a -> MWC.GenIO -> Int -> IO (S.Seq a)
randomElemsTR ixs s g n
    | n <= 0          = return $ S.unstableSort ixs
    | n >= S.length s = return $ S.unstableSort (ixs S.>< s)
    | otherwise       = do
        picked <- randomExtract s g
        case picked of
          -- Not expected: 0 < n < length s here, so s is non-empty.
          Nothing      -> return $ S.unstableSort ixs
          -- Force the accumulator so it does not build up as a thunk chain.
          Just (s', i) -> (randomElemsTR $! (i S.<| ixs)) s' g (n - 1)

{-|
  Chop a sequence at the given indices.
  Assume number of indices given < length of sequence to be chopped.

  Each index is converted into an offset from the previous one, so the
  indices are expected in ascending order. The pieces are returned in
  reverse order: the remainder after the last cut comes first and the
  piece before the first cut comes last.
-}
pieces :: S.Seq a -> S.Seq Int -> [S.Seq a]
pieces s ixs = remainder : chopped
  where
    -- State: pieces cut so far (newest first), what is left to cut,
    -- and the index of the previous cut.
    (chopped, remainder, _) = F.foldl' step ([], s, 0) ixs
    step (acc, rest, prev) ix = (front : acc, back, ix)
      where (front, back) = S.splitAt (ix - prev) rest

{-|
  Given n, chop a sequence at m random points
  where m = min (length-1, n-1).

  When @n@ is at least the length of the sequence, every element becomes
  its own singleton piece. When @n <= 1@ (and the sequence is longer than
  @n@) there is nothing to cut, so the whole sequence is returned as a
  single piece. Otherwise @n - 1@ cut points are drawn from
  @1 .. length - 1@ via 'randomElems' and handed to 'pieces'.
-}
randomPieces :: Int -> S.Seq a -> IO [S.Seq a]
randomPieces n s
    | n >= l    = return $ F.toList $ fmap S.singleton s
    -- Guard so that 'randomElems' is never asked for fewer than one
    -- cut point.
    | n <= 1    = return [s]
    | otherwise = do ixs <- randomElems (S.fromList [1..l-1]) (n-1)
                     return $ pieces s ixs
    where l = S.length s

{-|
  All pairs of elements taken from two distinct positions of a list.
  Within each pair the elements keep the order in which they occur
  in the list.

  > pairs [1,2,3,4]
  [(1,2),(1,3),(1,4),(2,3),(2,4),(3,4)]
  > pairs [1,2,4,4]
  [(1,2),(1,4),(1,4),(2,4),(2,4),(4,4)]
-}
pairs :: [a] -> [(a,a)]
pairs []          = []
pairs (hd : rest) = [(hd, other) | other <- rest] ++ pairs rest

{-|
  Euclidean (L2) norm of a list: the square root of the sum of the
  squares of its elements.
-}
l2Norm :: Floating a => [a] -> a
l2Norm l = sqrt (sum [x * x | x <- l])