{-# LANGUAGE RankNTypes, BangPatterns, MagicHash #-}

-- | A compilation of minor array combinators used extensively in "Data.RangeMin".
module Data.RangeMin.Internal.HandyArray (unsafeMemoize#, asPureArray, pureLookup, listLookup') where

import Data.RangeMin.Internal.HandyList()
import Control.Monad(Monad(..))
import Control.Monad.ST(ST, runST)
import Data.Array.ST(STArray)
import Data.Array.IArray(Ix(..), Array)
import Data.Array.Base
import GHC.Exts(Int#, Int(..), (+#), (-#), (==#))

-- | Identity on 'Array'.  Wrapping an expression whose array type is otherwise
-- ambiguous in @asPureArray@ pins that type to the standard boxed 'Array'.
asPureArray :: Ix i => Array i e -> Array i e
asPureArray arr = arr

-- | Builds an array of any 'IArray' type from a list.  The result is indexed
-- from @0@ to @length list - 1@ and holds the list's elements in order.
listToArray :: IArray a e
            => [e]      -- ^ The elements to store.
            -> a Int e  -- ^ A zero-indexed array containing exactly those elements.
listToArray list = listArray (0, lastIndex) list
  where
    lastIndex = length list - 1
{-# INLINE listToArray #-}

-- | The number of elements in an array, computed from its bounds.
arraySize :: (Ix i, IArray a e) => a i e -> Int
-- Kept as a lambda so the INLINE unfolding still applies to unapplied uses.
arraySize = \arr -> rangeSize (bounds arr)
{-# INLINE arraySize #-}

-- | Ordinary ('!') indexing with the array argument fixed to the standard
-- 'Array'.  Applying it to an otherwise untyped array expression therefore
-- does two things at once: it coerces that expression to an 'Array', and it
-- returns the lookup function of that array.
pureLookup :: Ix i => Array i e -> i -> e
-- Kept as a lambda so the INLINE unfolding still applies to unapplied uses.
pureLookup = \arr ix -> arr ! ix
{-# INLINE pureLookup #-}

-- | 'unsafeAt' with the array argument fixed to the standard 'Array': looks up
-- an element by 'Int' offset and does not range-check that offset.
pureUnsafeLookup :: Ix i => Array i e -> Int -> e
-- Kept as a lambda so the INLINE unfolding still applies to unapplied uses.
pureUnsafeLookup = \arr offset -> unsafeAt arr offset
{-# INLINE pureUnsafeLookup #-}

-- {-# INLINE unsafeMemoize #-}
-- A memoization function.
-- unsafeMemoize :: (Int -> e) -- ^ An arbitrary function on integer values.
-- 		-> Int		-- ^ n
-- 		-> (Int -> e)	-- ^ A function on integers from @0@ to @n-1@ that memoizes its values.
-- unsafeMemoize f (n+1) = unsafeMemoize' f n

{-# INLINE unsafeMemoize# #-}
-- | A memoization function that indexes to @n@ inclusive.
--
-- @unsafeMemoize# f n@ tabulates @f@ at the @n+1@ offsets @0@ to @n@ and
-- returns a function that reads the table.  The returned function performs
-- no range check on its argument.
unsafeMemoize# :: (Int# -> e) -- ^ An arbitrary function on integer values.
               -> Int         -- ^ n
               -> (Int# -> e) -- ^ A function on integers from @0@ to @n@ that memoizes its values.
unsafeMemoize# f n = {-# SCC "memoization" #-}
  -- The table is bound outside the lambda so that every call of the returned
  -- function shares it; with the construction inside the lambda, sharing
  -- would depend on the optimiser floating it out.
  (let table = listLookerUpper (memoizer f n) (n + 1)
   in \ i# -> table (I# i#))

{-# INLINE memoizer #-}
-- | @memoizer f n arr@ stores @f i@ at offset @i@ of @arr@ for every @i@ from
-- @0@ to @n@ inclusive, and does nothing when @n@ is negative.
--
-- The stores go through 'unsafeWrite', so the caller must supply an array
-- with room for at least @n+1@ elements.
memoizer :: (Int# -> e) -> Int -> STArray s Int e -> ST s ()
memoizer f n@(I# _) arr
  -- Without this guard a negative bound would never meet the stop condition
  -- below and the loop would write past the array.
  | n < 0     = return ()
  | otherwise = fill 0#
  where
    -- The recursive call is sequenced before the store, so the offsets are
    -- written from @n@ down to @0@.
    fill i#
      -- Compared on boxed 'Int's: the result type of the primitive (==#)
      -- differs between GHC versions (Bool before 7.8, Int# afterwards).
      | I# i# == n = store
      | otherwise  = fill (i# +# 1#) >> store
      where
        store = unsafeWrite arr (I# i#) (f i#)

-- | @listLookerUpper fill n@ allocates an 'STArray' with bounds @(0, n-1)@
-- using 'newArray_', runs @fill@ on it, freezes it with 'unsafeFreeze' and
-- returns the unchecked lookup function ('pureUnsafeLookup') of the
-- resulting array.  Supplying the array's contents is entirely up to @fill@.
listLookerUpper :: (forall s . STArray s Int e -> ST s ()) -> Int -> (Int -> e)
listLookerUpper fill size = pureUnsafeLookup (runST (do
    arr <- newArray_ (0, size - 1)
    fill arr
    unsafeFreeze arr))
{-# INLINE listLookerUpper #-}

-- | Fold state used by 'listLookup'': the offset just above the next slot to
-- be written, paired with the writes accumulated so far.
data Acc s = A Int# {-# UNPACK #-} !(ST s ())

-- | Given a list and its length @n@, memoizes lookups on the list: the
-- elements are copied into an @n@-slot array and the result is that array's
-- unchecked lookup function (see 'listLookerUpper').
--
-- Slots are assigned from the back: the last list element goes to offset
-- @n-1@, the one before it to @n-2@, and so on, so a list of exactly @n@
-- elements has its head at offset @0@.  A shorter list leaves the leading
-- offsets unwritten.  A list longer than @n@ makes the unchecked
-- 'unsafeWrite' target negative offsets, which will result in a segfault.
listLookup' :: Int -> [e] -> (Int -> e)
listLookup' n@(I# n#) l = {-# CORE "list_memoization" #-} listLookerUpper
    (\ arr ->
        case foldr (step arr) (A n# (return ())) l of
          A _ writes -> writes)
    n
  where
    -- Claims the slot just below the running offset for @x@ and schedules its
    -- write ahead of the writes gathered from the elements to its right.
    step arr x (A next# later) = A slot# (unsafeWrite arr (I# slot#) x >> later)
      where
        slot# = next# -# 1#
{-# INLINE listLookup' #-}