{-# OPTIONS_HADDOCK hide #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PartialTypeSignatures #-}

module Data.Array.Polarized.Pull.Internal where

import qualified Data.Functor.Linear as Data
import Prelude.Linear
import qualified Prelude
import Data.Vector (Vector)
import qualified Data.Vector as Vector

import qualified Unsafe.Linear as Unsafe

-- | A pull array is an array from which it is easy to extract elements, and
-- this can be done in any order. The linear consumption of a pull array means
-- each element is consumed exactly once, but the length can be accessed
-- freely.
data Array a where
  -- With LinearTypes enabled, the plain @->@ arrows in this GADT-syntax
  -- signature make both fields (the indexer and the length) unrestricted.
  -- That is what lets the functions in this module reuse the indexer and the
  -- length after matching on a linearly-bound array.
  Array :: (Int -> a) -> Int -> Array a
  -- The non-linear 'Prelude.Semigroup' instance is obtained via 'NonLinear',
  -- presumably delegating to the linear 'Semigroup' instance of this module.
  deriving Prelude.Semigroup via NonLinear (Array a)
  -- In the linear consumption of a pull array f n, (f i) should be consumed
  -- linearly for every 0 <= i < n. The exported functions (from non-internal
  -- modules) should enforce this invariant, but the current type of 'Array'
  -- does not.

-- | Mapping keeps the length and post-composes the indexer with the function.
-- The result is rebuilt with 'fromFunction', so it carries a bounds check.
instance Data.Functor Array where
  fmap f (Array g n) = fromFunction mapped n
    where
      mapped i = f (g i)

-- XXX: This should be well-typed without the unsafe coercion, but GHC does
-- not accept it. The pull array type is probably not the ideal choice: making
-- Array linear in (Int -> a) would mean only one value could be taken out of
-- the Array (which is interesting in itself: I think this is like an n-ary
-- With), and changing the other arrows makes no difference.


-- | Produce a pull array of length 1 consisting solely of the given element.
--
-- 'Unsafe.toLinear' is needed because @x@ is captured by the (unrestricted)
-- indexer. The type checker cannot see that a linear consumption of the
-- array reads index 0 exactly once, and so uses @x@ exactly once.
singleton :: a %1-> Array a
singleton = Unsafe.toLinear (\x -> fromFunction (\_ -> x) 1)

-- | Pair up two pull arrays element by element:
-- @zip [x1, ..., xn] [y1, ..., yn] = [(x1,y1), ..., (xn,yn)]@
--
-- __Partial:__ @zip [x1,x2,...,xn] [y1,y2,...,yp]@ is an error if @n ≠ p@.
zip :: Array a %1-> Array b %1-> Array (a,b)
zip (Array left lenL) (Array right lenR) =
  if lenL == lenR
    then fromFunction (\i -> (left i, right i)) lenL
    else error "Polarized.zip: size mismatch"

-- | Concatenate two pull arrays. Indices below the first array's length read
-- from the first array. The remaining indices read from the second array,
-- shifted down by that length.
append :: Array a %1-> Array a %1-> Array a
append (Array front lenF) (Array back lenB) = Array combined (lenF + lenB)
  where
    -- The raw 'Array' constructor is used, so no bounds check is added
    -- here. Any index >= lenF is handed to the second array's indexer.
    combined i
      | i < lenF = front i
      | otherwise = back (i - lenF)

-- | @'make' x n@ is a pull array of length @n@ in which every element is @x@.
make :: a -> Int -> Array a
make x = fromFunction (\_ -> x)

-- | The linear semigroup operation on pull arrays is concatenation.
instance Semigroup (Array a) where
  xs <> ys = append xs ys

-- | A right-fold of a pull array: for an array of length @n@ with indexer @g@,
-- @foldr f z arr = f (g 0) (f (g 1) (... (f (g (n-1)) z)))@.
--
-- An array whose length is zero or negative is treated as empty, and the
-- result is @z@.
foldr :: (a %1-> b %1-> b) -> b %1-> Array a %1-> b
foldr f z (Array g n) = go f z g n
  where
    go :: (x %1-> y %1-> y) -> y %1-> (Int -> x) -> Int -> y
    -- Walk the indices downwards from k-1 to 0, feeding each element and the
    -- accumulator to the combining function. Stopping on k <= 0 (rather than
    -- only on exactly 0) keeps a negative length from recursing forever.
    go f' z' g' k
      | k <= 0 = z'
      | otherwise = go f' (f' (g' (k - 1)) z') g' (k - 1)
    -- go is strict in its last argument

-- | Read the length of a pull array and hand the array back unchanged. The
-- array can still be consumed after its length has been inspected.
findLength :: Array a %1-> (Int, Array a)
findLength (Array indexer len) = (len, Array indexer len)

-- | @'fromFunction' arrIndexer len@ builds a pull array of length @len@ whose
-- element at index @i@ is @arrIndexer i@.
--
-- The stored indexer is wrapped in a bounds check. Asking for an index
-- outside @[0, len)@ raises an error instead of calling @arrIndexer@.
fromFunction :: (Int -> a) -> Int -> Array a
fromFunction f n = Array checked n
  where
    checked k =
      if k < 0
        then error "Pull.Array: negative index"
        else
          if k >= n
            then error "Pull.Array: index too large"
            else f k

-- XXX: this is used internally to ensure out of bounds errors occur, but
-- is unnecessary if the input function can be assumed to already have bounded
-- domain, for instance in `append`.

-- | This is a convenience function for @alloc . transfer@. It builds a
-- 'Vector' of the array's length whose elements come from the array's
-- indexer.
toVector :: Array a %1-> Vector a
toVector (Array indexer len) = Vector.generate len indexer
-- TODO: A test to make sure alloc . transfer == toVector

-- | @'split' n v = (vl, vr)@ where @vl@ holds the first @n@ elements of @v@
-- and @vr@ holds the rest.
--
-- 'split' is total: the split point is clamped into @[0, length v]@. If @n@
-- is larger than the length of @v@, then @vr@ is empty. If @n@ is negative,
-- then @vl@ is empty.
split :: Int -> Array a %1-> (Array a, Array a)
split k (Array f n) =
  (fromFunction f cut, fromFunction (\x -> f (x + cut)) (len - cut))
  where
    -- A negative length is treated as an empty array.
    len = max 0 n
    -- Clamping the split point keeps the right half from reading f at
    -- indices below 0 (which a negative k used to cause via x + k).
    -- For 0 <= k and 0 <= n the result matches the unclamped version.
    cut = max 0 (min k len)

-- | Reverse a pull array: index @i@ of the result reads index @n - 1 - i@ of
-- the input, where @n@ is the array's length.
reverse :: Array a %1-> Array a
reverse (Array f n) = Array (\x -> f (n - 1 - x)) n

-- | Read the element at the given index and hand the array back unchanged.
--
-- 'index' does no bounds check of its own. Whether an out-of-range index
-- fails depends on the array's indexer: arrays built with 'fromFunction'
-- raise an error, while arrays built directly with 'Array' may not.
index :: Array a %1-> Int -> (a, Array a)
index (Array indexer len) i = (indexer i, Array indexer len)