{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
{-# OPTIONS_HADDOCK hide #-}

module Data.Array.Mutable.Linear.Internal
  ( -- * Mutable Linear Arrays
    Array (..),

    -- * Performing Computations with Arrays
    alloc,
    allocBeside,
    fromList,

    -- * Modifications
    set,
    unsafeSet,
    resize,
    map,

    -- * Accessors
    get,
    unsafeGet,
    size,
    slice,
    toList,
    freeze,

    -- * Mutable-style interface
    read,
    unsafeRead,
    write,
    unsafeWrite,
  )
where

import Data.Array.Mutable.Unlifted.Linear (Array#)
import qualified Data.Array.Mutable.Unlifted.Linear as Unlifted
import qualified Data.Functor.Linear as Data
import qualified Data.Primitive.Array as Prim
import Data.Unrestricted.Linear
import qualified Data.Vector as Vector
import GHC.Stack
import Prelude.Linear (forget, (&))
import Prelude hiding (map, read)

-- # Data types
-------------------------------------------------------------------------------

-- | A mutable array of values of type @a@: a lifted wrapper around the
-- unlifted 'Array#' from "Data.Array.Mutable.Unlifted.Linear".
data Array a = Array (Array# a)

-- # Creation
-------------------------------------------------------------------------------

-- | Allocate an array of the given size in which every cell holds the given
-- initial value, and hand it to a linear continuation.
--
-- The size must be non-negative, otherwise this errors.
alloc ::
  HasCallStack =>
  Int ->
  a ->
  (Array a %1 -> Ur b) %1 ->
  Ur b
alloc len initial k
  | len >= 0 = Unlifted.alloc len initial (\raw -> k (Array raw))
  | otherwise =
      -- The continuation is a linear argument, so it still has to be consumed
      -- exactly once on the failing path; applying the (linear) error function
      -- to @k undefined@ does that.
      (error ("Array.alloc: negative size: " ++ show len) :: x %1 -> x)
        (k undefined)

-- | Allocate an array of the given size filled with the given initial value,
-- using an already existing array as a uniqueness proof. Returns the newly
-- allocated array paired with the array that was passed in.
--
-- Errors if the size is negative.
allocBeside :: Int -> a -> Array b %1 -> (Array a, Array b)
allocBeside len initial (Array witness)
  | len >= 0 = rebox (Unlifted.allocBeside len initial witness)
  | otherwise =
      -- The witness array is linear, so consume it before failing.
      Unlifted.lseq
        witness
        (error ("Array.allocBeside: negative size: " ++ show len))
  where
    -- Lift both halves of the unboxed result: fresh array first, witness second.
    rebox :: (# Array# a, Array# b #) %1 -> (Array a, Array b)
    rebox (# fresh, old #) = (Array fresh, Array old)

-- | Allocate an array holding the elements of the given list, in order, and
-- hand it to a linear continuation.
fromList ::
  forall a b.
  HasCallStack =>
  [a] ->
  (Array a %1 -> Ur b) %1 ->
  Ur b
fromList list f =
  alloc
    (Prelude.length list)
    -- Placeholder for the cells; every position is overwritten by 'fill'
    -- before the continuation sees the array.
    (error "invariant violation: unintialized array position")
    (\arr -> f (fill 0 list arr))
  where
    -- Write the remaining elements into consecutive positions, starting at
    -- the given index.
    fill :: Int -> [a] -> Array a %1 -> Array a
    fill _ [] arr = arr
    fill ix (x : rest) arr = fill (ix + 1) rest (unsafeSet ix x arr)

-- # Mutations and Reads
-------------------------------------------------------------------------------

-- | Return the size of the array, together with the array itself so that it
-- can keep being used.
size :: Array a %1 -> (Ur Int, Array a)
size (Array raw) = rebox (Unlifted.size raw)
  where
    -- Turn the unboxed result into a lifted pair.
    rebox :: (# Ur Int, Array# a #) %1 -> (Ur Int, Array a)
    rebox (# len, raw' #) = (len, Array raw')

-- | Set the value at the given index. The index must be non-negative and less
-- than the array's 'size', otherwise this errors.
set :: HasCallStack => Int -> a -> Array a %1 -> Array a
set i x arr = assertIndexInRange i arr & unsafeSet i x

-- | Like 'set', but without bounds-checking. The behaviour is undefined if an
-- out-of-bounds index is provided.
unsafeSet :: Int -> a -> Array a %1 -> Array a
unsafeSet i x (Array raw) = Array (Unlifted.set i x raw)

-- | Read the value at the given index, returning it together with the array.
-- The index must be non-negative and less than the array's 'size', otherwise
-- this errors.
get :: HasCallStack => Int -> Array a %1 -> (Ur a, Array a)
get i arr = assertIndexInRange i arr & unsafeGet i

-- | Like 'get', but without bounds-checking. The behaviour is undefined if an
-- out-of-bounds index is provided.
unsafeGet :: Int -> Array a %1 -> (Ur a, Array a)
unsafeGet i (Array raw) = rebox (Unlifted.get i raw)
  where
    -- Turn the unboxed result into a lifted pair.
    rebox :: (# Ur a, Array# a #) %1 -> (Ur a, Array a)
    rebox (# val, raw' #) = (val, Array raw')

-- | Resize an array: given a target size, a seed value and an array, produce
-- an array of the target size in which the cells shared with the input keep
-- their contents and any additional cells hold the seed value.
--
-- The target size must be non-negative, otherwise this errors.
--
-- @
-- let b = resize n x a,
--   then size b = n,
--   and b[i] = a[i] for i < size a,
--   and b[i] = x for size a <= i < n.
-- @
resize :: HasCallStack => Int -> a -> Array a %1 -> Array a
resize newSize seed (Array arr :: Array a)
  | newSize >= 0 = copyOver (Unlifted.allocBeside newSize seed arr)
  | otherwise =
      -- The input array is linear, so consume it before failing.
      Unlifted.lseq arr (error "Trying to resize to a negative size.")
  where
    -- Copy the old array into the freshly allocated, seed-filled one.
    copyOver :: (# Array# a, Array# a #) %1 -> Array a
    copyOver (# fresh, old #) = keepDst (Unlifted.copyInto 0 old fresh)

    -- Discard the copy source and keep only the destination.
    keepDst :: (# Array# a, Array# a #) %1 -> Array a
    keepDst (# src, dst #) = Unlifted.lseq src (Array dst)

-- | Consume the array and return its elements as a lazy list.
toList :: Array a %1 -> Ur [a]
toList (Array raw) = Unlifted.toList raw

-- | Copy a slice of the array, starting from the given offset and copying the
-- given number of elements. Returns the pair (oldArray, slice).
--
-- The start offset and the target size must both be non-negative, and
-- start offset + target size must not exceed the size of the input array;
-- otherwise this errors.
--
-- @
-- let b = slice i n a,
--   then size b = n,
--   and b[j] = a[i+j] for 0 <= j < n
-- @
slice ::
  HasCallStack =>
  -- | Start offset
  Int ->
  -- | Target size
  Int ->
  Array a %1 ->
  (Array a, Array a)
slice from targetSize arr =
  size arr & \case
    (Ur s, Array old)
      -- The contract requires both arguments to be non-negative, so reject
      -- negatives here instead of handing them to 'Unlifted.allocBeside' and
      -- 'Unlifted.copyInto'. With @from@ known to be non-negative, comparing
      -- against @s - from@ expresses @from + targetSize > s@ without the
      -- risk of the sum overflowing.
      | from < 0 || targetSize < 0 || targetSize > s - from ->
          Unlifted.lseq
            old
            (error "Slice index out of bounds.")
      | otherwise ->
          doCopy
            ( Unlifted.allocBeside
                targetSize
                -- Placeholder for the cells of the new array, which are
                -- meant to be overwritten by the copy below.
                (error "invariant violation: uninitialized array index")
                old
            )
  where
    -- Fill the freshly allocated array from the old one, starting at @from@.
    doCopy :: (# Array# a, Array# a #) %1 -> (Array a, Array a)
    doCopy (# new, old #) = wrap (Unlifted.copyInto from old new)

    -- Lift the unboxed (old, new) result into a pair of 'Array's.
    wrap :: (# Array# a, Array# a #) %1 -> (Array a, Array a)
    wrap (# old, new #) = (Array old, Array new)

-- | /O(1)/ Consume the 'Array' and convert it to an immutable
-- 'Vector.Vector' (from the 'vector' package).
freeze :: Array a %1 -> Ur (Vector.Vector a)
freeze (Array raw) =
  Unlifted.freeze (\prim -> Vector.fromArray (Prim.Array prim)) raw

-- | Apply an unrestricted function to the array via 'Unlifted.map', consuming
-- the input array and returning the resulting one.
map :: (a -> b) -> Array a %1 -> Array b
map g (Array raw) = Array (Unlifted.map g raw)

-- # Mutation-style API
-------------------------------------------------------------------------------

-- | Bounds-checked write: 'set' with the 'Array' as the first parameter.
write :: HasCallStack => Array a %1 -> Int -> a -> Array a
write arr i x = arr & set i x

-- | Unchecked write: 'unsafeSet' with the 'Array' as the first parameter.
unsafeWrite :: Array a %1 -> Int -> a -> Array a
unsafeWrite arr i x = arr & unsafeSet i x

-- | Bounds-checked read: 'get' with the 'Array' as the first parameter.
read :: HasCallStack => Array a %1 -> Int -> (Ur a, Array a)
read arr i = arr & get i

-- | Unchecked read: 'unsafeGet' with the 'Array' as the first parameter.
unsafeRead :: Array a %1 -> Int -> (Ur a, Array a)
unsafeRead arr i = arr & unsafeGet i

-- # Instances
-------------------------------------------------------------------------------

-- | An 'Array' is consumed by consuming the unlifted array it wraps.
instance Consumable (Array a) where
  consume :: Array a %1 -> ()
  consume (Array raw) = Unlifted.lseq raw ()

-- | An 'Array' is duplicated by delegating to 'Unlifted.dup2' on the
-- unlifted array it wraps.
instance Dupable (Array a) where
  dup2 :: Array a %1 -> (Array a, Array a)
  dup2 (Array raw) = rebox (Unlifted.dup2 raw)
    where
      -- Lift both halves of the unboxed result.
      rebox :: (# Array# a, Array# a #) %1 -> (Array a, Array a)
      rebox (# left, right #) = (Array left, Array right)

-- | Data functor instance, implemented with 'map' after forgetting the
-- linearity of the supplied function.
instance Data.Functor Array where
  fmap :: (a %1 -> b) -> Array a %1 -> Array b
  fmap g = map (forget g)

-- # Internal library
-------------------------------------------------------------------------------

-- | Return the array unchanged if the given index satisfies
-- @0 <= i < size@; otherwise consume the array and error.
assertIndexInRange :: HasCallStack => Int -> Array a %1 -> Array a
assertIndexInRange i arr =
  size arr & \case
    (Ur len, checked)
      | i >= 0 && i < len -> checked
      | otherwise -> lseq checked (error "Array: index out of bounds")