{-# LANGUAGE CPP
           , NoImplicitPrelude
           , UnicodeSyntax
           , BangPatterns
           , ScopedTypeVariables
  #-}

module Utils where

--------------------------------------------------------------------------------
-- Imports
--------------------------------------------------------------------------------

-- from base:
import Prelude ( ($)
               , Num, (+), (*), (-)
               , Enum, toEnum, fromEnum
               , Integral, fromIntegral, undefined
               )

#if __GLASGOW_HASKELL__ < 700
import Prelude               ( fromInteger )
#endif

import Control.Monad         ( Monad, return, (>>=), (>>) )
import Foreign.Ptr           ( Ptr )
import Foreign.ForeignPtr    ( withForeignPtr )
import Foreign.Storable      ( Storable, peek, sizeOf )
import Foreign.Marshal.Alloc ( alloca )
import Foreign.Marshal.Utils ( copyBytes )
import Data.Bool             ( Bool, otherwise )
import Data.Ord              ( Ord, (>) )
import Data.Bits             ( Bits, shiftL, shiftR, bitSize, (.&.) )
import Data.Int              ( Int )
import Data.Maybe            ( Maybe(Nothing, Just) )
import System.IO             ( IO )
import GHC.ForeignPtr        ( mallocPlainForeignPtrBytes )

-- from vector:
import           Data.Vector                ( Vector )
import qualified Data.Vector          as V  ( null, unsafeHead, unsafeTail )
import qualified Data.Vector.Storable as VS ( Vector, empty, null
                                            , unsafeFromForeignPtr0
                                            , unsafeToForeignPtr0
                                            )
import qualified Data.Vector.Generic  as VG ( Vector, mapM, convert )

-- from base-unicode-symbols:
import Data.Function.Unicode ( (∘) )
import Data.Ord.Unicode      ( (≥), (≤) )
import Data.Bool.Unicode     ( (∧) )


--------------------------------------------------------------------------------
-- Utils
--------------------------------------------------------------------------------

-- | @bits s e b@ extracts bits @s@ to @e@ (inclusive) from @b@.
--
-- @b@ is shifted right by @s@ and then masked with @e - s + 1@ low one-bits,
-- so the extracted field ends up in the least significant bits of the result.
bits ∷ (Bits α, Num α) ⇒ Int → Int → α → α
bits s e b = ((1 `shiftL` (e - s + 1)) - 1) .&. (b `shiftR` s)

-- | @between n b e@ tests whether @n@ lies between the bounds @b@ and @e@
-- (both inclusive), i.e. whether @b ≤ n ≤ e@.
between ∷ Ord α ⇒ α → α → α → Bool
between n b e = n ≥ b ∧ n ≤ e

-- | A generalized 'toEnum' that works on any 'Integral' type.
--
-- The argument is first converted to 'Int' with 'fromIntegral' and then
-- passed to 'toEnum'.
genToEnum ∷ (Integral i, Enum e) ⇒ i → e
genToEnum = toEnum ∘ fromIntegral

-- | A generalized 'fromEnum' that returns any 'Integral' type.
--
-- The value is first converted to 'Int' with 'fromEnum' and then to the
-- requested 'Integral' type with 'fromIntegral'.
genFromEnum ∷ (Integral i, Enum e) ⇒ e → i
genFromEnum = fromIntegral ∘ fromEnum

-- | @mapPeekArray f n a@ applies the monadic function @f@ to each of the @n@
-- elements of the array @a@ and returns the results in a vector.
--
-- The elements are first copied out of the array into a storable vector, which
-- is then converted to the requested vector type before @f@ is mapped over it.
mapPeekArray ∷ (Storable a, VG.Vector v a, VG.Vector v b) ⇒ (a → IO b) → Int → Ptr a → IO (v b)
mapPeekArray f n a = peekVector n a >>= VG.mapM f ∘ VG.convert

-- | @peekVector size ptr@ copies @size@ elements starting at @ptr@ into a
-- freshly allocated storable vector.
--
-- When @size@ is zero or negative the empty vector is returned and @ptr@ is
-- not read.
peekVector ∷ forall a. (Storable a) ⇒ Int → Ptr a → IO (VS.Vector a)
peekVector size ptr
    | size ≤ 0  = return VS.empty
    | otherwise = do
        -- Number of bytes occupied by @size@ elements of type @a@.
        let n = (size * sizeOf (undefined ∷ a))
        fp ← mallocPlainForeignPtrBytes n
        withForeignPtr fp $ \p → copyBytes p ptr n
        return $ VS.unsafeFromForeignPtr0 fp size

-- | @pokeVector ptr v@ copies all elements of the storable vector @v@ to the
-- memory starting at @ptr@.
--
-- Nothing is written when @v@ is empty. No bounds checking is performed on the
-- destination: the number of bytes copied is the vector length times the
-- element size.
pokeVector ∷ forall a. Storable a ⇒ Ptr a → VS.Vector a → IO ()
pokeVector ptr v | VS.null v = return ()
                 | otherwise = withForeignPtr fp $ \p →
                     copyBytes ptr p (size * sizeOf (undefined ∷ a))
    where
      (fp, size) = VS.unsafeToForeignPtr0 v

-- | @allocaPeek f@ allocates temporary storage for a single value, runs @f@ on
-- a pointer to that storage and then returns the value read from the pointer.
--
-- The storage is not initialised before @f@ runs, so @f@ is expected to write
-- a value through the pointer.
allocaPeek ∷ Storable α ⇒ (Ptr α → IO ()) → IO α
allocaPeek f = alloca $ \ptr → f ptr >> peek ptr

-- | Monadic if...then...else...
--
-- @ifM cM tM eM@ runs the condition @cM@ and then runs @tM@ when it yields
-- 'True' or @eM@ otherwise. Only the selected branch is executed.
ifM ∷ Monad m ⇒ m Bool → m α → m α → m α
ifM cM tM eM = cM >>= \c → if c then tM else eM

{-| @decodeBCD bitsInDigit bcd@ decodes the Binary Coded Decimal @bcd@ to a list
of its encoded digits. @bitsInDigit@, which is usually 4, is the number of bits
used to encode a single digit. See:
<http://en.wikipedia.org/wiki/Binary-coded_decimal>

The digits are returned most significant first and cover the whole width of
@bcd@ as reported by 'bitSize', so leading zero digits are included.

@bitsInDigit@ must be positive, otherwise the decoding does not terminate.
-}
decodeBCD ∷ Bits α ⇒ Int → α → [α]
decodeBCD bitsInDigit abcd = go 0
    where
      -- Right shift that moves the topmost digit down to the lowest bits.
      shftR = bitSize abcd - bitsInDigit

      -- Shift the wanted digit to the top (dropping the digits before it),
      -- then shift it down to the bottom (dropping the digits after it).
      go !shftL | shftL > shftR = []
                | otherwise     = let !d = (abcd `shiftL` shftL) `shiftR` shftR
                                  in d : go (shftL + bitsInDigit)

-- | Splits a vector into its first element and the remaining elements.
--
-- Returns 'Nothing' for an empty vector. The unchecked head and tail
-- operations are only reached after the emptiness test.
uncons ∷ Vector α → Maybe (α, Vector α)
uncons v | V.null v  = Nothing
         | otherwise = Just (V.unsafeHead v, V.unsafeTail v)