{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE RankNTypes #-}

-- | This module provides push arrays.
--
-- These are part of a larger framework for controlling when memory is
-- allocated for an array. See @Data.Array.Polarized@.
--
-- This module is designed to be imported qualified as @Push@.
module Data.Array.Polarized.Push
  (
  -- * Construction
    Array(..)
  , make
  , singleton
  , cons
  , snoc
  -- * Operations
  , alloc
  , foldMap
  , unzip
  )
where

import qualified Data.Functor.Linear as Data
import qualified Data.Array.Destination as DArray
import Data.Array.Destination (DArray)
import Data.Vector (Vector)
import qualified Prelude
import Prelude.Linear hiding (unzip, foldMap)
import GHC.Stack


-- The Types
-------------------------------------------------------------------------------

-- | A push array: a complete description of an array's elements that has not
-- yet been allocated.
--
-- Push arrays can be passed along and enlarged (see 'cons', 'snoc' and
-- 'append') without allocating; memory is only allocated by 'alloc'.
data Array a where
  -- Developer notes:
  --
  -- The function @(a -> m)@ writes a single element into some monoid @m@.
  -- The stored function @((a -> m) -> m)@ takes such a single-element writer,
  -- applies it to every element, and concatenates the results in order.
  --
  -- In this encoding the length of the array is not known beforehand.
  Array :: forall a. (forall m. Monoid m => (a -> m) -> m) %1-> Array a

-- | A pending write into a destination array, together with the length that
-- destination array must have.
data ArrayWriter a where
  -- First field: fills the destination array.
  -- Second field: the length of that @DArray@.
  ArrayWriter :: forall a. (DArray a %1-> ()) %1-> !Int -> ArrayWriter a


-- API
-------------------------------------------------------------------------------

-- | Convert a push array into a vector by allocating. This would be a common
-- end to a computation using push and pull arrays.
--
-- Each element is turned into a one-slot 'ArrayWriter'. The writers are
-- concatenated by the push array itself, and a single destination of the
-- total length is then allocated and filled.
alloc :: Array a %1-> Vector a
alloc (Array k) = allocArrayWriter $ k singletonWriter
  where
    -- A writer of length 1 that fills its destination with @a@.
    singletonWriter :: a -> ArrayWriter a
    singletonWriter a = ArrayWriter (DArray.fill a) 1

    -- Allocate a destination array of the accumulated length and run the
    -- accumulated writer on it.
    allocArrayWriter :: ArrayWriter a %1-> Vector a
    allocArrayWriter (ArrayWriter writer len) = DArray.alloc len writer

-- | @`make` x n@ creates a constant push array of length @n@ in which every
-- element is @x@.
--
-- Calls 'error' if @n@ is negative.
make :: HasCallStack => a -> Int -> Array a
make x n
  | n < 0 = error "Making a negative length push array"
  | otherwise = Array (\makeA -> mconcat $ Prelude.replicate n (makeA x))

-- | A push array containing exactly one element.
singleton :: a -> Array a
singleton x = Array (\writeA -> writeA x)

-- | @`snoc` x arr@ adds @x@ after all the elements of @arr@.
snoc :: a -> Array a %1-> Array a
snoc x (Array k) = Array (\writeA -> (k writeA) <> (writeA x))

-- | @`cons` x arr@ adds @x@ before all the elements of @arr@.
cons :: a -> Array a %1-> Array a
cons x (Array k) = Array (\writeA -> (writeA x) <> (k writeA))

-- | Map every element into a monoid and combine the results in order.
foldMap :: Monoid b => (a -> b) -> Array a %1-> b
foldMap f (Array k) = k f

-- | Split a push array of pairs into a push array of first components and a
-- push array of second components.
--
-- Each pair becomes a pair of singletons; these are combined with the
-- 'Monoid' instance on pairs, which appends component-wise.
unzip :: Array (a,b) %1-> (Array a, Array b)
unzip (Array k) = k (\(a,b) -> (singleton a, singleton b))


-- Instances
-------------------------------------------------------------------------------

-- | Mapping over a push array composes @f@ with the element writer; nothing
-- is allocated.
instance Data.Functor Array where
  fmap f (Array k) = Array (\g -> k (\x -> g (f x)))

instance Prelude.Semigroup (Array a) where
  -- Eta-expanded so that the linear 'append' is applied to the (unrestricted)
  -- arguments rather than being used directly at an unrestricted type.
  (<>) x y = append x y

instance Semigroup (Array a) where
  (<>) = append

instance Prelude.Monoid (Array a) where
  mempty = empty

instance Monoid (Array a) where
  mempty = empty

-- | The push array with no elements: it never calls the element writer and
-- produces the monoid's identity.
empty :: Array a
empty = Array (\_ -> mempty)

-- | Concatenate two push arrays: all elements of the first are written,
-- followed by all elements of the second.
append :: Array a %1-> Array a %1-> Array a
append (Array k1) (Array k2) = Array (\writeA -> k1 writeA <> k2 writeA)

instance Prelude.Semigroup (ArrayWriter a) where
  -- Eta-expanded so that the linear 'addWriters' is applied to the
  -- (unrestricted) arguments rather than being used directly at an
  -- unrestricted type.
  (<>) x y = addWriters x y

instance Prelude.Monoid (ArrayWriter a) where
  mempty = emptyWriter

instance Semigroup (ArrayWriter a) where
  (<>) = addWriters

instance Monoid (ArrayWriter a) where
  mempty = emptyWriter

-- | Sequence two writers.
--
-- The combined writer expects a destination of length @l1 + l2@. It splits
-- the destination at index @l1@ and lets each writer fill its own part.
addWriters :: ArrayWriter a %1-> ArrayWriter a %1-> ArrayWriter a
addWriters (ArrayWriter k1 l1) (ArrayWriter k2 l2) =
  ArrayWriter
    (\darr ->
      DArray.split l1 darr & \(darr1, darr2) ->
        -- Both writers return (); 'consume' discards the pair of units.
        consume (k1 darr1, k2 darr2))
    (l1 + l2)

-- | The writer that writes nothing. It has length 0 and discards its
-- destination with 'DArray.dropEmpty'.
emptyWriter :: ArrayWriter a
emptyWriter = ArrayWriter DArray.dropEmpty 0
-- Remark. @emptyWriter@ assumes that a destination array can be split at
-- index 0, which is what 'addWriters' does when this writer is its left
-- operand.