{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnliftedFFITypes #-}
module Std.Data.PrimArray.BitTwiddle where
import GHC.Prim
import GHC.Types
import GHC.Word
import Data.Primitive.PrimArray
#include "MachDeps.h"
#if SIZEOF_HSWORD == 4
# define CAST_OFFSET_WORD_TO_BYTE(x) (x `uncheckedIShiftL#` 2#)
# define CAST_OFFSET_BYTE_TO_WORD(x) (x `uncheckedIShiftRA#` 2#)
#else
# define CAST_OFFSET_WORD_TO_BYTE(x) (x `uncheckedIShiftL#` 3#)
# define CAST_OFFSET_BYTE_TO_WORD(x) (x `uncheckedIShiftRA#` 3#)
#endif
-- | Test whether a byte offset is a multiple of the machine word size.
--
-- The word size is 4 or 8 bytes, a power of two, so checking the low bits
-- with a bitwise AND is enough; no division is needed.
isOffsetAligned# :: Int# -> Bool
{-# INLINE isOffsetAligned# #-}
isOffsetAligned# off# = isTrue# (andI# off# (SIZEOF_HSWORD# -# 1#) ==# 0#)
-- | Replicate the low byte of the argument into every byte of a word.
--
-- For example @0x2A@ becomes @0x2A2A2A2A@ on 32-bit targets and
-- @0x2A2A2A2A2A2A2A2A@ on 64-bit targets. The argument is expected to fit
-- in one byte; higher bits are OR-ed into the result as well.
mkMask# :: Word# -> Word#
{-# INLINE mkMask# #-}
mkMask# byte# =
#if SIZEOF_HSWORD == 4
    spread# (spread# byte# 8#) 16#
#else
    spread# (spread# (spread# byte# 8#) 16#) 32#
#endif
  where
    -- OR a value with a copy of itself shifted left by the given number of
    -- bits, doubling the width of the replicated pattern each time.
    spread# :: Word# -> Int# -> Word#
    spread# x# n# = x# `or#` (x# `uncheckedShiftL#` n#)
-- | Zero-byte detector for a whole machine word.
--
-- The result is nonzero exactly when at least one byte of the argument is
-- 0x00. The individual bits set in a nonzero result are not meaningful
-- above the lowest zero byte, so callers should only compare the result
-- against zero.
--
-- XOR-ing a word with a broadcast needle first turns "some byte equals the
-- needle" into "some byte is zero", which is how the search loops use it.
nullByteMagic# :: Word# -> Word#
{-# INLINE nullByteMagic# #-}
nullByteMagic# v# =
    let
#if SIZEOF_HSWORD == 4
        lows#  = 0x01010101##
        highs# = 0x80808080##
#else
        lows#  = 0x0101010101010101##
        highs# = 0x8080808080808080##
#endif
    in  (v# `minusWord#` lows#) `and#` not# v# `and#` highs#
-- | Find the first occurrence of a byte in a slice of a 'PrimArray'.
--
-- Looks at the @siz@ bytes at offsets @s .. s + siz - 1@. It returns the
-- absolute offset of the first byte equal to the target, counted from the
-- start of the array rather than from @s@. It returns @-1@ when nothing
-- matches, including when @siz <= 0@.
memchr :: PrimArray Word8 -- ^ array to search
       -> Word8           -- ^ byte to look for
       -> Int             -- ^ start offset, in bytes
       -> Int             -- ^ number of bytes to search
       -> Int
{-# INLINE memchr #-}
memchr (PrimArray arr#) (W8# needle#) (I# off#) (I# len#) =
    case memchr# arr# needle# off# len# of
        r# -> I# r#
-- | Unboxed worker for 'memchr'.
--
-- Searches the @len#@ bytes of @arr#@ at offsets @start# .. start# + len# - 1@
-- for the byte @needle#@. It returns the absolute offset of the first
-- match, or @-1#@ when there is none (in particular when @len# <= 0#@).
--
-- The scan runs in three phases:
--
--   1. Compare single bytes until the cursor reaches a word boundary.
--   2. Test one whole machine word per step with 'nullByteMagic#', as long
--      as the word lies entirely inside the window.
--   3. Compare single bytes again, starting at the word where phase 2
--      stopped, to pin down the exact offset of a hit or to finish the
--      unaligned tail.
--
-- Phase 2 relies on the payload of a 'ByteArray#' starting on a word
-- boundary, so that an aligned byte offset maps to a whole word index.
--
-- The @NOINLINE@ is presumably there so this loop is compiled once instead
-- of at every call site of the inlined 'memchr' wrapper.
memchr# :: ByteArray# -> Word# -> Int# -> Int# -> Int#
{-# NOINLINE memchr# #-}
memchr# arr# needle# start# len# = byteScan# start# (start# +# len#)
  where
    -- Phase 1. @stop#@ is the exclusive upper bound of the window.
    byteScan# :: Int# -> Int# -> Int#
    byteScan# i# stop#
        | isTrue# (i# >=# stop#) = -1#
        | isTrue# (needle# `eqWord#` indexWord8Array# arr# i#) = i#
        | isOffsetAligned# i# =
            -- Switch to word steps. The word bound rounds down, so only
            -- words lying completely below stop# are ever read.
            wordScan# (mkMask# needle#)
                      CAST_OFFSET_BYTE_TO_WORD(i#)
                      CAST_OFFSET_BYTE_TO_WORD(stop#)
                      stop#
        | otherwise = byteScan# (i# +# 1#) stop#

    -- Phase 2. wi# and stopW# are word indices; stop# is a byte offset.
    wordScan# :: Word# -> Int# -> Int# -> Int# -> Int#
    wordScan# mask# wi# stopW# stop#
        | isTrue# (wi# <# stopW#)
        , isTrue# (nullByteMagic# (mask# `xor#` indexWordArray# arr# wi#)
                     `eqWord#` 0##)
        = wordScan# mask# (wi# +# 1#) stopW# stop#
        -- Either the window has no whole words left or this word holds a
        -- match; in both cases continue bytewise from its first byte.
        | otherwise
        = tailScan# (mask# `and#` 0xFF##) CAST_OFFSET_WORD_TO_BYTE(wi#) stop#

    -- Phase 3. Plain byte comparison up to the exclusive bound stop#.
    tailScan# :: Word# -> Int# -> Int# -> Int#
    tailScan# byte# i# stop#
        | isTrue# (i# >=# stop#) = -1#
        | isTrue# (byte# `eqWord#` indexWord8Array# arr# i#) = i#
        | otherwise = tailScan# byte# (i# +# 1#) stop#
-- | Find the last occurrence of a byte in a slice of a 'PrimArray'. The
-- scan starts at a given offset and moves towards lower offsets.
--
-- The window is the @siz@ bytes at offsets @s - siz + 1 .. s@, the mirror
-- image of 'memchr', which covers @s .. s + siz - 1@. The result is the
-- absolute offset of the highest matching byte in that window, or @-1@
-- when there is none. A non-positive @siz@ always yields @-1@.
--
-- No bounds checks are performed. When @siz > 0@ the caller must keep the
-- window inside the array: @0 <= s - siz + 1@ and @s@ below the array length.
memchrReverse :: PrimArray Word8  -- ^ array to search
              -> Word8            -- ^ byte to look for
              -> Int              -- ^ start offset: the highest offset examined
              -> Int              -- ^ number of bytes to search
              -> Int
{-# INLINE memchrReverse #-}
memchrReverse (PrimArray ba#) (W8# c#) (I# s#) (I# siz#) =
    -- Delegate to the reverse worker, not the forward 'memchr#'.
    I# (memchrReverse# ba# c# s# siz#)

-- | Unboxed reverse counterpart of 'memchr#'.
--
-- Searches the @siz#@ bytes of @ba#@ at offsets @s# - siz# + 1 .. s#@,
-- highest offset first, for the byte @c#@. It returns the absolute offset
-- of the last occurrence, or @-1#@ when there is none (in particular when
-- @siz# <= 0#@).
--
-- The scan runs in three phases:
--
--   1. Compare single bytes downwards until the cursor sits on a word
--      boundary. The byte just below it is then the top byte of a word.
--   2. Test one whole machine word per step with 'nullByteMagic#', moving
--      down, as long as the word lies entirely inside the window.
--   3. Compare single bytes downwards again, starting at the top byte of
--      the word where phase 2 stopped. This finds the highest match inside
--      that word, or finishes the unaligned head of the window.
memchrReverse# :: ByteArray# -> Word# -> Int# -> Int# -> Int#
{-# NOINLINE memchrReverse# #-}
memchrReverse# ba# c# s# siz# = beforeAlignedLoop# s# (s# -# siz#)
  where
    -- Phase 1. end# is the exclusive lower bound of the window.
    beforeAlignedLoop# :: Int# -> Int# -> Int#
    beforeAlignedLoop# i# end#
        | isTrue# (i# <=# end#) = -1#
        | isTrue# (c# `eqWord#` indexWord8Array# ba# i#) = i#
        -- i# is the first byte of a word, so the next byte to examine
        -- (i# - 1) is the top byte of the preceding word.
        | isOffsetAligned# i# =
            alignedLoop# (mkMask# c#) (CAST_OFFSET_BYTE_TO_WORD(i#) -# 1#) end#
        | otherwise = beforeAlignedLoop# (i# -# 1#) end#

    -- Phase 2. w# is a word index; end# is still a byte offset.
    alignedLoop# :: Word# -> Int# -> Int# -> Int#
    alignedLoop# mask# w# end#
        -- The lowest byte of this word is outside the window: finish bytewise
        -- without reading the word.
        | isTrue# (CAST_OFFSET_WORD_TO_BYTE(w#) <=# end#) =
            afterAlignedLoop# (topByte# w#) end#
        | otherwise = case nullByteMagic# (mask# `xor#` indexWordArray# ba# w#) of
            0## -> alignedLoop# mask# (w# -# 1#) end#
            -- Some byte of this word equals c#; locate the highest one.
            _   -> afterAlignedLoop# (topByte# w#) end#

    -- Byte offset of the highest byte of word w#.
    topByte# :: Int# -> Int#
    topByte# w# = CAST_OFFSET_WORD_TO_BYTE(w#) +# (SIZEOF_HSWORD# -# 1#)

    -- Phase 3. Plain byte comparison down to the exclusive bound end#.
    afterAlignedLoop# :: Int# -> Int# -> Int#
    afterAlignedLoop# i# end#
        | isTrue# (i# <=# end#) = -1#
        | isTrue# (c# `eqWord#` indexWord8Array# ba# i#) = i#
        | otherwise = afterAlignedLoop# (i# -# 1#) end#
-- | FFI binding to the C symbol @hs_memchr@, which is defined outside this
-- module. NOTE(review): the argument order suggests array, offset, byte,
-- length; confirm this against the C prototype.
--
-- The call is declared @unsafe@. Passing an unpinned 'ByteArray#' to C is
-- only sound for unsafe calls, because no garbage collection can move the
-- array while such a call runs.
--
-- NOTE(review): nothing in this module calls it; check other modules
-- before removing it.
foreign import ccall unsafe "hs_memchr" c_memchr ::
    ByteArray# -> Int -> Word8 -> Int -> Int