{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
{-# OPTIONS_HADDOCK hide #-}

module Data.Array.Polarized.Pull.Internal where

import qualified Data.Functor.Linear as Data
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import Prelude.Linear
import qualified Unsafe.Linear as Unsafe
import qualified Prelude

-- | A pull array: an array whose elements are easy to extract, in any order.
--
-- Consuming a pull array linearly means that every element is consumed
-- exactly once; the length, on the other hand, may be inspected freely.
-- The constructor stores an (unrestricted) indexing function together with
-- the length of the array.
data Array a where
  Array :: (Int -> a) -> Int -> Array a
  deriving (Prelude.Semigroup) via (NonLinear (Array a))

-- In the linear consumption of a pull array @Array f n@, each @f i@ should be
-- consumed exactly once for every @0 <= i < n@. The exported functions (from
-- non-internal modules) should enforce this invariant, but the current type
-- of 'Array' does not.

-- Mapping composes the element function with the indexing function; the
-- length is left untouched and the result is re-wrapped with 'fromFunction'
-- so that out-of-range indices are still reported.
instance Data.Functor Array where
  fmap f (Array get len) = fromFunction mapped len
    where
      mapped i = f (get i)

-- XXX: This should be well-typed without the unsafe, but it isn't accepted:
-- the pull array type probably isn't the ideal choice (making Array linear in
-- (Int -> a) would mean only one value could be taken out of the Array (which
-- is interesting in and of itself: I think this is like an n-ary With), and
-- changing the other arrows makes no difference)

-- | Produce a pull array of length 1 containing only the given element.
--
-- 'Unsafe.toLinear' is needed because the element is captured by the
-- unrestricted indexing function stored in 'Array'; the single index of the
-- resulting array still yields the element exactly once under linear
-- consumption.
singleton :: a %1 -> Array a
singleton = Unsafe.toLinear (\x -> fromFunction (\_ -> x) 1)

-- | Pair up two pull arrays element-wise:
-- @zip [x1, ..., xn] [y1, ..., yn] = [(x1,y1), ..., (xn,yn)]@
--
-- __Partial:__ @zip [x1,x2,...,xn] [y1,y2,...,yp]@ is an error if @n ≠ p@.
zip :: Array a %1 -> Array b %1 -> Array (a, b)
zip (Array g n) (Array h m) =
  if n == m
    then fromFunction (\k -> (g k, h k)) n
    else error "Polarized.zip: size mismatch"

-- | Concatenate two pull arrays. The result's length is the sum of both
-- lengths; its first indices range over the left array, the remaining ones
-- over the right array.
append :: Array a %1 -> Array a %1 -> Array a
append (Array left leftLen) (Array right rightLen) =
  Array combined (leftLen + rightLen)
  where
    -- Indices below leftLen belong to the left array; the rest are shifted
    -- down into the right array. No extra bounds check is added here:
    -- out-of-range indices are left to each side's indexer (which checks
    -- them when it was built with 'fromFunction').
    combined i
      | i < leftLen = left i
      | otherwise = right (i - leftLen)

-- | @make x n@ creates a pull array of length @n@ in which every index
-- yields @x@.
make :: a -> Int -> Array a
make x = fromFunction (\_ -> x)

-- Linear concatenation of pull arrays is 'append'.
instance Semigroup (Array a) where
  xs <> ys = append xs ys

-- | A right-fold of a pull array: for an array with elements
-- @x0, x1, ..., x(n-1)@, @foldr f z arr = f x0 (f x1 (... (f x(n-1) z)))@.
--
-- An array whose recorded length is zero or negative is treated as empty,
-- and the fold returns @z@ unchanged.
foldr :: (a %1 -> b %1 -> b) -> b %1 -> Array a %1 -> b
foldr f z (Array g n) = go f z g n
  where
    -- @go f' acc g' k@ folds the elements at indices @0 .. k-1@ into @acc@,
    -- starting from index @k-1@. It is strict in its last argument @k@.
    -- Stopping at @k <= 0@ (rather than only at @k == 0@) prevents
    -- endless recursion when the length is negative, e.g. for
    -- @make x (-1)@.
    go :: (_ %1 -> _ %1 -> _) -> _ %1 -> _ -> _ -> _
    go f' z' g' k
      | k <= 0 = z'
      | otherwise = go f' (f' (g' (k - 1)) z') g' (k - 1)

-- | Read off the length of an array, returning it alongside an array that is
-- equivalent to the one passed in (same indexing function, same length).
findLength :: Array a %1 -> (Int, Array a)
findLength (Array get len) = (len, rebuilt)
  where
    rebuilt = Array get len

-- | @'fromFunction' arrIndexer len@ builds a pull array of length @len@
-- whose element at index @i@ is @arrIndexer i@.
--
-- The stored indexer is bounds-checked: asking for an index below @0@
-- raises @"Pull.Array: negative index"@, and asking for an index at or
-- beyond @len@ raises @"Pull.Array: index too large"@.
fromFunction :: (Int -> a) -> Int -> Array a
fromFunction f n = Array checked n
  where
    checked k
      | k >= 0, k < n = f k
      | k < 0 = error "Pull.Array: negative index"
      | otherwise = error "Pull.Array: index too large"

-- XXX: this is used internally to ensure out of bounds errors occur, but
-- is unnecessary if the input function can be assumed to already have bounded
-- domain, for instance in `append`.

-- | Convert a pull array into a 'Vector' whose element at index @i@ is the
-- array's element at index @i@, for every @i@ below the array's length.
-- This is a convenience function for @alloc . transfer@.
toVector :: Array a %1 -> Vector a
toVector (Array get len) = Vector.generate len get

-- TODO: A test to make sure alloc . transfer == toVector

-- | @'split' n v = (vl, vr)@ where @vl@ holds the first @n@ elements of @v@
-- and @vr@ holds the rest (so @vl@ has length @n@ whenever
-- @0 <= n <= length v@).
--
-- 'split' is total: the split point is clamped into @[0, length v]@, so if
-- @n@ is larger than the length of @v@ then @vr@ is empty, and if @n@ is
-- negative then @vl@ is empty.
split :: Int -> Array a %1 -> (Array a, Array a)
split k (Array f n) =
  (fromFunction f cut, fromFunction (\x -> f (x + cut)) (max 0 (n - cut)))
  where
    -- Without clamping at 0, a negative k would give vl a negative length
    -- and vr more than n elements, including negative source indices.
    -- The outer max 0 keeps vr empty if the recorded length n is itself
    -- negative.
    cut = max 0 (min k n)

-- | Reverse a pull array: for an array of length @n@, index @i@ of the
-- result is index @n - 1 - i@ of the input.
reverse :: Array a %1 -> Array a
reverse (Array f n) = Array (\x -> f (n - 1 - x)) n

-- | Read the element at the given index, returning it together with an
-- array equivalent to the input (same indexing function, same length).
--
-- No bounds check is performed here; an out-of-range index is only caught
-- if the underlying indexing function checks it (as those built by
-- 'fromFunction' do).
index :: Array a %1 -> Int -> (a, Array a)
index (Array get len) i = (get i, unchanged)
  where
    unchanged = Array get len