-- Copyright 2020 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
module Data.Array.Internal(module Data.Array.Internal) where
import Control.DeepSeq
import Data.Data(Data)
import qualified Data.DList as DL
import Data.Kind (Type)
import Data.List(foldl', zipWith4, zipWith5, sortBy, sortOn)
import Data.Proxy
import GHC.Exts(Constraint, build)
import GHC.Generics(Generic)
import GHC.TypeLits(KnownNat, natVal)
import Text.PrettyPrint
import Text.PrettyPrint.HughesPJClass

{- HLINT ignore "Reduce duplication" -}

-- The underlying storage of values must be an instance of Vector.
-- For some types, like unboxed vectors, we require an extra
-- constraint on the elements, which VecElem allows you to express.
-- For vector types that don't need the constraint it can be set
-- to some dummy class.
-- | The 'Vector' class is the interface to the underlying storage for the arrays.
-- The operations map straight to operations for 'Vector'.
--
-- The list instance in this module acts as the reference semantics for
-- every method; the per-method notes below describe what that instance does.
class Vector v where
  -- | Extra per-element constraint required by the storage type.
  type VecElem v :: Type -> Constraint
  -- | Zero-based indexing (list instance: '!!').
  vIndex    :: (VecElem v a) => v a -> Int -> a
  -- | Number of elements.
  vLength   :: (VecElem v a) => v a -> Int
  -- | All elements, in order, as a list.
  vToList   :: (VecElem v a) => v a -> [a]
  -- | Build a vector from a list of elements.
  vFromList :: (VecElem v a) => [a] -> v a
  -- | A vector holding exactly one element.
  vSingleton:: (VecElem v a) => a -> v a
  -- | @vReplicate n x@: @n@ copies of @x@.
  vReplicate:: (VecElem v a) => Int -> a -> v a
  -- | Apply a function to every element.
  vMap      :: (VecElem v a, VecElem v b) => (a -> b) -> v a -> v b
  -- | Element-wise combination of two vectors (list instance: 'zipWith').
  vZipWith  :: (VecElem v a, VecElem v b, VecElem v c) => (a -> b -> c) -> v a -> v b -> v c
  -- | Element-wise combination of three vectors.
  vZipWith3 :: (VecElem v a, VecElem v b, VecElem v c, VecElem v d) => (a -> b -> c -> d) -> v a -> v b -> v c -> v d
  -- | Element-wise combination of four vectors.
  vZipWith4 :: (VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e) => (a -> b -> c -> d -> e) -> v a -> v b -> v c -> v d -> v e
  -- | Element-wise combination of five vectors.
  vZipWith5 :: (VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e, VecElem v f) => (a -> b -> c -> d -> e -> f) -> v a -> v b -> v c -> v d -> v e -> v f
  -- | Concatenate two vectors.
  vAppend   :: (VecElem v a) => v a -> v a -> v a
  -- | Concatenate a list of vectors.
  vConcat   :: (VecElem v a) => [v a] -> v a
  -- | Left fold with a starting value (list instance: @foldl'@).
  vFold     :: (VecElem v a) => (a -> a -> a) -> a -> v a -> a
  -- | @vSlice o n v@: @n@ elements of @v@ starting at offset @o@
  -- (list instance: @take n . drop o@).
  vSlice    :: (VecElem v a) => Int -> Int -> v a -> v a
  -- | Sum of the elements.
  vSum      :: (VecElem v a, Num a) => v a -> a
  -- | Product of the elements.
  vProduct  :: (VecElem v a, Num a) => v a -> a
  -- | Largest element (list instance: 'maximum').
  vMaximum  :: (VecElem v a, Ord a) => v a -> a
  -- | Smallest element (list instance: 'minimum').
  vMinimum  :: (VecElem v a, Ord a) => v a -> a
  -- | @vUpdate v ius@: replace the element at each index @i@ with the paired
  -- value.  The list instance calls 'error' for an index past the end, and
  -- also for a negative or repeated index.
  vUpdate   :: (VecElem v a) => v a -> [(Int, a)] -> v a
  -- | @vGenerate n f@: the vector @[f 0, f 1, .., f (n-1)]@.
  vGenerate :: (VecElem v a) => Int -> (Int -> a) -> v a
  -- | Does the predicate hold for every element?
  vAll      :: (VecElem v a) => (a -> Bool) -> v a -> Bool
  -- | Does the predicate hold for at least one element?
  vAny      :: (VecElem v a) => (a -> Bool) -> v a -> Bool

-- | A class with no methods and a single instance covering every type, so
-- the constraint @None a@ always holds.  It is used as the 'VecElem' of
-- storage types that put no requirement on their elements (e.g. lists).
class None a
instance None a

-- This instance is not used anywhere.  It serves more as a reference semantics
-- for the 'Vector' methods, expressed with ordinary list functions.
instance Vector [] where
  type VecElem [] = None
  vIndex = (!!)
  vLength = length
  vToList = id
  vFromList = id
  vSingleton = pure
  vReplicate = replicate
  vMap = map
  vZipWith = zipWith
  vZipWith3 = zipWith3
  vZipWith4 = zipWith4
  vZipWith5 = zipWith5
  vAppend = (++)
  vConcat = concat
  vFold = foldl'
  -- Slice of length n starting at offset o.
  vSlice o n = take n . drop o
  vSum = sum
  vProduct = product
  vMaximum = maximum
  vMinimum = minimum
  -- Walk the list once, in step with the updates sorted by index;
  -- the last argument of loop is the position of the current element.
  vUpdate xs us = loop xs (sortOn fst us) 0
    where
      loop [] [] _ = []
      -- Updates remain but the list is exhausted.
      loop [] (_:_) _ = error "vUpdate: out of bounds"
      loop as [] _ = as
      loop (a:as) ias@((i,a'):ias') n =
        case compare i n of
          -- The index is already behind us: negative or a duplicate.
          LT -> error "vUpdate: bad index"
          EQ -> a' : loop as ias' (n+1)
          GT -> a  : loop as ias  (n+1)
  vGenerate n f = map f [0 .. n-1]
  vAll = all
  vAny = any

-- | Render a 'Pretty' value to a 'String' at the given 'PrettyLevel',
-- using precedence 0 (i.e. top level).
prettyShowL :: (Pretty a) => PrettyLevel -> a -> String
prettyShowL l = render . pPrintPrec l 0

-- | Sizes, strides, offsets and indices.
-- We expect all N to be non-negative, but we use Int for convenience.
type N = Int

-- | The type /T/ is the internal type of arrays.  In general,
-- operations on /T/ do no sanity checking as that should be done
-- at the point of call.
--
-- To avoid manipulating the data the indexing into the vector containing
-- the data is somewhat complex.  To find where item /i/ of the outermost
-- dimension starts you calculate vector index @offset + i*strides[0]@.
-- To find where item /i,j/ of the two outermost dimensions is you
-- calculate vector index @offset + i*strides[0] + j*strides[1]@, etc.
--
-- Note that the shape is not stored in /T/; functions that need it take
-- it as a separate argument.
data T v a = T
    { strides :: [N]      -- ^ one stride per dimension; length is tensor rank
    , offset  :: !N       -- ^ offset into vector of values
    , values  :: !(v a)   -- ^ actual values
    }
    deriving (Show, Generic, Data)

-- The instance has no explicit methods, so 'rnf' comes from the class
-- default (presumably the 'Generic'-based one, since 'T' derives 'Generic').
instance NFData (v a) => NFData (T v a)

-- | The shape of an array is a list of its dimensions, outermost first.
-- Dimensions are expected to be non-negative; see 'badShape'.
type ShapeL = [Int]

-- | A shape is bad if any of its dimensions is negative.
badShape :: ShapeL -> Bool
badShape = any (< 0)

-- Equality of two arrays that share the shape given as the first argument.
-- When shapes match, we can be efficient and use loop-fused comparisons instead
-- of materializing a list.
--
-- Fast path: identical strides, offset and underlying vector mean the arrays
-- are trivially equal.  Otherwise both are laid out in natural order with
-- 'toVectorT' and the resulting vectors are compared.
equalT :: (Vector v, VecElem v a, Eq a, Eq (v a))
                  => ShapeL -> T v a -> T v a -> Bool
equalT s x y | strides x == strides y
               && offset x == offset y
               && values x == values y = True
             | otherwise = toVectorT s x == toVectorT s y

-- Compare two arrays by comparing their elements laid out in natural order
-- (via 'toVectorT'), using the 'Ord' instance of the underlying vector type.
-- Note this assumes the shape is the same for both Vectors.
compareT :: (Vector v, VecElem v a, Ord a, Ord (v a))
            => ShapeL -> T v a -> T v a -> Ordering
compareT s x y = compare (toVectorT s x) (toVectorT s y)

-- Given the dimensions, return the stride in the underlying vector
-- for each dimension.  The first element of the list is the total length.
-- The result has one more element than the shape and always ends in 1,
-- e.g. @getStridesT [2,3,4] == [24,12,4,1]@ and @getStridesT [] == [1]@.
{-# INLINE getStridesT #-}
getStridesT :: ShapeL -> [N]
getStridesT = scanr (*) 1

-- Convert an array to a list by indexing through all the elements.
-- The first argument is the array shape.
-- If the array is in canonical layout the underlying vector is converted
-- directly; otherwise the elements are visited dimension by dimension
-- (outermost first) and the list is produced with 'build' so that it can
-- take part in list fusion.
-- XXX Copy special cases from Tensor.
{-# INLINE toListT #-}
toListT :: (Vector v, VecElem v a) => ShapeL -> T v a -> [a]
toListT sh a@(T ss0 o0 v)
  | isCanonicalT (getStridesT sh) a = vToList v
  | otherwise = build $ \cons nil ->
      -- TODO: because unScalarT uses vIndex, this has unnecessary bounds
      -- checks.  We should expose an unchecked indexing function in the Vector
      -- class, add top-level bounds checks to cover the full range we'll
      -- access, and then do all accesses with the unchecked version.
      --
      -- go takes the remaining dimensions, the remaining strides, the current
      -- offset and the list built so far for everything to the right.
      let go []     ss o rest = cons (unScalarT (T ss o v)) rest
          go (n:ns) ss o rest = foldr
            (\i -> case indexT (T ss o v) i of T ss' o' _ -> go ns ss' o')
            rest
            [0..n-1]
      in  go sh ss0 o0 nil

-- | Check if the strides are canonical, i.e., if the vector has the natural layout.
-- The first argument is the result of 'getStridesT' for the array's shape:
-- the total length followed by the normal strides.  It is never empty for
-- such an argument, so the second equation should be unreachable.
-- XXX Copy special cases from Tensor.
{-# INLINE isCanonicalT #-}
isCanonicalT :: (Vector v, VecElem v a) => [N] -> T v a -> Bool
isCanonicalT (n:ss') (T ss o v) =
    o == 0 &&         -- Vector offset is 0
    ss == ss' &&      -- All strides are normal
    vLength v == n    -- The vector is the right size
isCanonicalT _ _ = error "impossible"

-- Convert a value to a scalar array: rank 0 (no strides), offset 0,
-- and a one-element vector holding the value.
{-# INLINE scalarT #-}
scalarT :: (Vector v, VecElem v a) => a -> T v a
scalarT = T [] 0 . vSingleton

-- Convert a scalar array to the actual value, i.e. the element of the
-- underlying vector found at the array's offset.  The strides are ignored.
{-# INLINE unScalarT #-}
unScalarT :: (Vector v, VecElem v a) => T v a -> a
unScalarT (T _ o v) = vIndex v o

-- Make a constant array of the given shape.
-- Only one copy of the value is stored: every stride is 0, so all indices
-- map to the single element of the underlying vector.
{-# INLINE constantT #-}
constantT :: (Vector v, VecElem v a) => ShapeL -> a -> T v a
constantT sh x = T (map (const 0) sh) 0 (vSingleton x)

-- TODO: change to return a list of vectors.
-- Convert an array to a vector in the natural order.
-- The first argument is the array shape.  Three strategies are used:
--
--  * all strides are normal and the vector has exactly the right length:
--    return the underlying vector unchanged;
--  * the innermost stride is normal (or the array is a scalar): concatenate
--    the largest contiguous slices of the underlying vector;
--  * otherwise every slice would have length 1, so go via 'toListT'.
{-# INLINE toVectorT #-}
toVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
toVectorT sh a@(T ats ao v) =
  let l : ts' = getStridesT sh
      -- Are strides ok from this point?
      oks = scanr (&&) True (zipWith (==) ats ts')
      -- Is the innermost stride normal?  The last element of oks is the
      -- scanr seed and hence always True, so the innermost dimension is the
      -- one just before it.  A rank 0 array has no strides to be wrong.
      innermostOk = null sh || oks !! (length sh - 1)
      -- Arguments: stride flags, remaining dimensions, remaining strides,
      -- and the current offset into the underlying vector.
      loop _ [] _ o =
        DL.singleton (vSlice o 1 v)
      loop (b:bs) (s:ss) (t:ts) o =
        if b then
          -- All strides normal from this point,
          -- so just take a slice of the underlying vector.
          DL.singleton (vSlice o (s*t) v)
        else
          -- Strides are not normal, collect slices.
          DL.concat [ loop bs ss ts (i*t + o) | i <- [0 .. s-1] ]
      loop _ _ _ _ = error "impossible"
  in  if head oks && vLength v == l then
        -- All strides are normal, return entire vector
        v
      else if innermostOk then  -- Special case for speed.
        -- Innermost dimension is normal, so slices are non-trivial.
        vConcat $ DL.toList $ loop oks sh ats ao
      else
        -- All slices would have length 1, going via a list is faster.
        vFromList $ toListT sh a

-- Convert to a vector containing the right elements,
-- but not necessarily in the right order.
-- The first argument is the array shape.
{-# INLINE toUnorderedVectorT #-}
toUnorderedVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
toUnorderedVectorT sh a@(T ats ao v) =
  -- Figure out if the array maps onto some contiguous slice of the vector.
  -- Do this by checking if a transposition of the array corresponds to
  -- normal strides.
  -- First sort the strides in descending order, and rearrange the shape the same way.
  -- Then compute the strides from this rearranged shape; these will be the normal
  -- strides for this shape.  If these strides agree with the sorted actual strides
  -- it is a transposition, and we can just slice out the relevant piece of the vector.
  let
    (ats', sh') = unzip $ sortBy (flip compare) $ zip ats sh
    -- l is the total number of elements, ts' the normal strides.
    l : ts' = getStridesT sh'
  in
      if ats' == ts' then
        vSlice ao l v
      else
        -- Not a transposition of a contiguous block; fall back to the
        -- ordered conversion.
        toVectorT sh a

-- Convert from a vector whose elements are in natural order for the
-- given shape.  The result has normal strides and offset 0.
-- The length of the vector is not checked against the shape.
{-# INLINE fromVectorT #-}
fromVectorT :: ShapeL -> v a -> T v a
-- getStridesT puts the total length first; drop it to get the strides.
fromVectorT sh = T (tail $ getStridesT sh) 0

-- Convert from a list whose elements are in natural order for the given
-- shape: the list is turned into a vector and wrapped with 'fromVectorT'.
{-# INLINE fromListT #-}
fromListT :: (Vector v, VecElem v a) => [N] -> [a] -> T v a
fromListT sh = fromVectorT sh . vFromList

-- Index into the outermost dimension of an array.
-- No data is copied: the outermost stride is dropped and the offset is
-- advanced by the index times that stride.  The index is not range checked.
-- Calling this on a rank 0 array (no strides) is an error.
{-# INLINE indexT #-}
indexT :: T v a -> N -> T v a
indexT (T (s : ss) o v) i = T ss (o + i * s) v
indexT _ _ = error "impossible"

-- Stretch the given dimensions to have arbitrary size.
-- The stretched dimensions must have size 1, and stretching is
-- done by setting the stride to 0.
-- The list of flags is matched up with the strides, outermost first;
-- a True flag marks a dimension to stretch.
{-# INLINE stretchT #-}
stretchT :: [Bool] -> T v a -> T v a
stretchT bs (T ss o v) = T (zipWith (\ b s -> if b then 0 else s) bs ss) o v

-- Map over the array elements.
-- The first argument is the array shape.
{-# INLINE mapT #-}
mapT :: (Vector v, VecElem v a, VecElem v b) => ShapeL -> (a -> b) -> T v a -> T v b
-- The array has at least as many elements as the underlying vector, so map
-- the vector directly and keep strides and offset (presumably so that the
-- function is not applied more often than necessary).
mapT sh f (T ss o v) | product sh >= vLength v = T ss o (vMap f v)
-- Otherwise lay the elements out in natural order first.
mapT sh f t = fromVectorT sh $ vMap f $ toVectorT sh t

-- Zip two arrays with a function.
-- The first argument is the shape, which both arrays are assumed to have.
{-# INLINE zipWithT #-}
zipWithT :: (Vector v, VecElem v a, VecElem v b, VecElem v c) =>
            ShapeL -> (a -> b -> c) -> T v a -> T v b -> T v c
zipWithT sh f t@(T ss _ v) t'@(T _ _ v') =
  case (vLength v, vLength v') of
    (1, 1) ->
      -- If both vectors have length 1, then it's a degenerate case and it's better
      -- to operate on the single element directly.
      -- The strides of the first array are reused for the result.
      T ss 0 $ vSingleton $ f (vIndex v 0) (vIndex v' 0)
    (1, _) ->
      -- First vector has length 1, so use a map instead.
      mapT sh (vIndex v 0 `f` ) t'
    (_, 1) ->
      -- Second vector has length 1, so use a map instead.
      mapT sh (`f` vIndex v' 0) t
    (_, _) ->
      -- General case: bring both arrays into natural order and zip the vectors.
      let cv  = toVectorT sh t
          cv' = toVectorT sh t'
      in  fromVectorT sh $ vZipWith f cv cv'

-- Zip three arrays with a function.
-- The first argument is the shape, which all arrays are assumed to have.
{-# INLINE zipWith3T #-}
zipWith3T :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d) =>
             ShapeL -> (a -> b -> c -> d) -> T v a -> T v b -> T v c -> T v d
zipWith3T _ f (T ss _ v) (T _ _ v') (T _ _ v'') |
  -- If all vectors have length 1, then it's a degenerate case and it's better
  -- to operate on the single element directly.
  -- The strides of the first array are reused for the result.
  vLength v == 1, vLength v' == 1, vLength v'' == 1 =
    T ss 0 $ vSingleton $ f (vIndex v 0) (vIndex v' 0) (vIndex v'' 0)
-- General case: bring all arrays into natural order and zip the vectors.
zipWith3T sh f t t' t'' = fromVectorT sh $ vZipWith3 f v v' v''
  where v   = toVectorT sh t
        v'  = toVectorT sh t'
        v'' = toVectorT sh t''

-- Zip four arrays with a function.
-- Each input is linearised with toVectorT using the given shape, the
-- vectors are combined with vZipWith4, and the result is rebuilt with
-- fromVectorT for the same shape.
{-# INLINE zipWith4T #-}
zipWith4T :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e) =>
             ShapeL -> (a -> b -> c -> d -> e) ->
             T v a -> T v b -> T v c -> T v d -> T v e
zipWith4T sh f t t' t'' t''' = fromVectorT sh $ vZipWith4 f v v' v'' v'''
  where v    = toVectorT sh t
        v'   = toVectorT sh t'
        v''  = toVectorT sh t''
        v''' = toVectorT sh t'''

-- Zip five arrays with a function.
-- Each input is linearised with toVectorT using the given shape, the
-- vectors are combined with vZipWith5, and the result is rebuilt with
-- fromVectorT for the same shape.
{-# INLINE zipWith5T #-}
zipWith5T :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e, VecElem v f) =>
             ShapeL -> (a -> b -> c -> d -> e -> f) ->
             T v a -> T v b -> T v c -> T v d -> T v e -> T v f
zipWith5T sh f t t' t'' t''' t'''' = fromVectorT sh $ vZipWith5 f v v' v'' v''' v''''
  where v     = toVectorT sh t
        v'    = toVectorT sh t'
        v''   = toVectorT sh t''
        v'''  = toVectorT sh t'''
        v'''' = toVectorT sh t''''

-- Do an arbitrary transposition.  The first argument should be
-- a permutation of the dimensions, i.e., the numbers [0..r-1] in some order
-- (where r is the rank of the array).
-- Only the stride list is permuted; the offset and the underlying vector
-- are reused unchanged.
{-# INLINE transposeT #-}
transposeT :: [Int] -> T v a -> T v a
transposeT is (T ss o v) = T (permute is ss) o v

-- Return all subarrays n dimensions down.
-- The shape argument should be a prefix of the array shape.
-- The subarrays are produced by indexing with indexT along each listed
-- dimension in turn, in increasing index order.
{-# INLINE subArraysT #-}
subArraysT :: ShapeL -> T v a -> [T v a]
subArraysT sh ten = sub sh ten []
  where
    -- sub yields a function that prepends the subarrays to an accumulator,
    -- so results from sibling indices are combined by composition.
    sub [] t = (t :)
    sub (n:ns) t = foldr (.) id [sub ns (indexT t i) | i <- [0..n-1]]

-- Reverse the given dimensions.
-- The first argument lists the (0-based) dimensions to reverse, the second
-- is the shape of the array.  Only the offset and strides change; the
-- underlying vector is reused unchanged.
{-# INLINE reverseT #-}
reverseT :: [N] -> ShapeL -> T v a -> T v a
reverseT rs sh (T ats ao v) = T rts ro v
  where (ro, rts) = rev 0 sh ats
        -- rev walks shape and strides together, tracking the dimension
        -- number r.  It returns the new offset and the new strides.
        rev !_ [] [] = (ao, [])
        rev r (m:ms) (t:ts)
          -- A reversed dimension starts at its last element (m-1 steps of
          -- stride t further on) and then walks backwards.
          | r `elem` rs = (o + (m-1)*t, -t : ts')
          | otherwise   = (o,            t : ts')
          where (o, ts') = rev (r+1) ms ts
        -- Shape and stride lists of different lengths.
        rev _ _ _ = error "reverseT: impossible"

-- Reduction of all array elements.
-- The array is linearised with toVectorT, folded with vFold using the given
-- function and starting value, and the result is wrapped with scalarT.
{-# INLINE reduceT #-}
reduceT :: (Vector v, VecElem v a) =>
           ShapeL -> (a -> a -> a) -> a -> T v a -> T v a
reduceT sh f z = scalarT . vFold f z . toVectorT sh

-- Right fold via toListT.
-- The elements are visited in the order toListT produces them.
{-# INLINE foldrT #-}
foldrT
  :: (Vector v, VecElem v a) => ShapeL -> (a -> b -> b) -> b -> T v a -> b
foldrT sh f z a = foldr f z (toListT sh a)

-- Traversal via toListT/fromListT.
-- The effects are sequenced in the order toListT produces the elements; the
-- result is rebuilt with fromListT for the same shape.
{-# INLINE traverseT #-}
traverseT
  :: (Vector v, VecElem v a, VecElem v b, Applicative f)
  => ShapeL -> (a -> f b) -> T v a -> f (T v b)
traverseT sh f a = fromListT sh <$> traverse f (toListT sh a)

-- Fast check if all elements are equal.
-- An array without elements counts as having all elements equal.
allSameT :: (Vector v, VecElem v a, Eq a) => ShapeL -> T v a -> Bool
allSameT sh t@(T _ _ v)
  -- At most one stored element, so there is nothing to compare.
  | vLength v <= 1 = True
  | otherwise =
    let !v' = toVectorT sh t
    in  -- The linearised array can be empty even though the backing vector
        -- is not (presumably when the shape contains a 0), so check before
        -- indexing element 0.
        vLength v' == 0 ||
        (let !x = vIndex v' 0 in vAll (x ==) v')

-- Pretty print an array as a Doc.
-- The array is rendered to a String by ppT_ (elements shown with
-- prettyShowL at the given level), passed through box with prettyBoxMode,
-- and each resulting line becomes one line of a vertical Doc.  The Doc is
-- parenthesised when the precedence argument exceeds 10.
ppT
  :: (Vector v, VecElem v a, Pretty a)
  => PrettyLevel -> Rational -> ShapeL -> T v a -> Doc
ppT l p sh =
  maybeParens (p > 10) . vcat . map text . box prettyBoxMode . ppT_ (prettyShowL l) sh

-- Render an array as a String, given a function to show one element.
-- Every element is shown, left-padded with spaces to the width of the widest
-- one, and the padded strings are laid out by showsT.  Trailing newlines are
-- removed from the result.
ppT_
  :: (Vector v, VecElem v a)
  => (a -> String) -> ShapeL -> T v a -> String
ppT_ show_ sh t = revDropWhile (== '\n') $ showsT sh t' ""
  where ss = map show_ $ toListT sh t
        -- Width of the widest element.  Only demanded by padSP, so it is
        -- never evaluated when there are no elements.
        n = maximum $ map length ss
        ss' = map padSP ss
        padSP s = replicate (n - length s) ' ' ++ s
        -- The padded strings as a list-backed array with offset 0; the head
        -- of getStridesT is dropped to obtain the strides.
        t' :: T [] String
        t' = T (tail (getStridesT sh)) 0 ss'

-- Lay out an array of already-rendered elements.
--  * a shape with a leading 0 shows as "EMPTY";
--  * a scalar shows as its single element;
--  * a rank-1 array shows its elements separated by spaces;
--  * a higher-rank array shows each subarray along the first dimension,
--    each followed by a newline.
showsT :: [N] -> T [] String -> ShowS
showsT (0:_)  _ = showString "EMPTY"
showsT []     t = showString $ unScalarT t
showsT s@[_]  t = showString $ unwords $ toListT s t
showsT (n:ns) t =
    foldr (.) id [ showsT ns (indexT t i) . showString "\n" | i <- [0..n-1] ]

-- Options controlling how box frames rendered text.
data BoxMode = BoxMode
  { _bmBars    :: Bool  -- put a vertical bar on both sides of non-empty lines
  , _bmUnicode :: Bool  -- use the bar character U+2502 instead of '|'
  , _bmHeader  :: Bool  -- add a "+---+" line before and after the text
  }

-- The box mode used by ppT: every option is switched off.
prettyBoxMode :: BoxMode
prettyBoxMode = BoxMode
  { _bmBars    = False
  , _bmUnicode = False
  , _bmHeader  = False
  }

-- Split a rendered string into lines and optionally frame it.
-- With _bmBars every non-empty line gets a bar on each side ('|', or U+2502
-- with _bmUnicode).  With _bmHeader a "+---+" line is added before and after;
-- its width is taken from the first line.
box :: BoxMode -> String -> [String]
box BoxMode{..} s = framed
  where
    bar | _bmUnicode = '\x2502'
        | otherwise  = '|'
    ls = lines s
    ls' | _bmBars   = map (\ l -> if null l then l else bar : l ++ [bar]) ls
        | otherwise = ls
    -- Width of the first line; 0 when there are no lines at all, so that
    -- empty input does not crash when the header is requested.
    width = case ls of
      l : _ -> length l
      []    -> 0
    h = "+" ++ replicate width '-' ++ "+"
    framed | _bmHeader = [h] ++ ls' ++ [h]
           | otherwise = ls'

-- Like zipWith, but the result is always as long as the second list:
-- once the first list runs out, the remaining elements of the second list
-- are returned unchanged.  Surplus elements of the first list are ignored.
zipWithLong2 :: (a -> b -> b) -> [a] -> [b] -> [b]
zipWithLong2 f (a:as) (b:bs) = f a b : zipWithLong2 f as bs
zipWithLong2 _ _      bs     = bs

-- Pad an array with a fill value.
-- Arguments: the fill value, a (low, high) padding amount for each leading
-- dimension, the shape of the array, and the array.  Returns the new shape
-- together with the padded array.  If there are fewer padding pairs than
-- dimensions, the remaining dimensions are left as they are; more pairs than
-- dimensions is an error.
padT :: forall v a . (Vector v, VecElem v a) => a -> [(Int, Int)] -> ShapeL -> T v a -> ([Int], T v a)
padT v aps ash at = (ss, fromVectorT ss $ vConcat $ pad' aps ash st at)
  where -- Produce the pieces of the padded array in order.  The third argument
        -- is the list of strides of the new shape; n*l and n*h are the
        -- numbers of fill elements before and after the current dimension.
        pad' :: [(Int, Int)] -> ShapeL -> [Int] -> T v a -> [v a]
        pad' [] sh _ t = [toVectorT sh t]
        pad' ((l,h):ps) (s:sh) (n:ns) t =
          [vReplicate (n*l) v] ++ concatMap (pad' ps sh ns . indexT t) [0..s-1] ++ [vReplicate (n*h) v]
        pad' _ _ _ _ = error $ "pad: rank mismatch: " ++ show (length aps, length ash)
        -- Strides of the padded shape (the head of getStridesT is dropped).
        _ : st = getStridesT ss
        -- New shape: each padded dimension grows by its low and high amounts.
        ss = zipWithLong2 (\ (l,h) s -> l+s+h) aps ash

-- Check if a reshape is just adding/removing some dimensions of
-- size 1, in which case it can be done by just manipulating
-- the strides.  Given the old strides, the old shape, and the
-- new shape it will return the possible new strides.
simpleReshape :: [N] -> ShapeL -> ShapeL -> Maybe [N]
simpleReshape osts os ns
  -- Old and new dimensions agree where they are not 1.
  | filter (1 /=) os == filter (1 /=) ns = Just $ loop ns sts'
  | otherwise = Nothing
    where
      -- Get old strides for non-1 dimensions
      sts' = [ st | (st, s) <- zip osts os, s /= 1 ]
      -- Insert stride 0 for all 1 dimensions in new shape.
      loop [] [] = []
      loop (1:ss)     sts  = 0  : loop ss sts
      loop (_:ss) (st:sts) = st : loop ss sts
      loop _ _ = error $ "simpleReshape: shouldn't happen: " ++ show (osts, os, ns)

-- Sum of all elements: vSum over the vector from toUnorderedVectorT.
{-# INLINE sumT #-}
sumT :: (Vector v, VecElem v a, Num a) => ShapeL -> T v a -> a
sumT sh = vSum . toUnorderedVectorT sh

-- Product of all elements: vProduct over the vector from toUnorderedVectorT.
{-# INLINE productT #-}
productT :: (Vector v, VecElem v a, Num a) => ShapeL -> T v a -> a
productT sh = vProduct . toUnorderedVectorT sh

-- Largest element: vMaximum over the vector from toUnorderedVectorT.
{-# INLINE maximumT #-}
maximumT :: (Vector v, VecElem v a, Ord a) => ShapeL -> T v a -> a
maximumT sh = vMaximum . toUnorderedVectorT sh

-- Smallest element: vMinimum over the vector from toUnorderedVectorT.
{-# INLINE minimumT #-}
minimumT :: (Vector v, VecElem v a, Ord a) => ShapeL -> T v a -> a
minimumT sh = vMinimum . toUnorderedVectorT sh

-- Does any element satisfy the predicate?
-- Uses vAny over the vector from toUnorderedVectorT.
{-# INLINE anyT #-}
anyT :: (Vector v, VecElem v a) => ShapeL -> (a -> Bool) -> T v a -> Bool
anyT sh p = vAny p . toUnorderedVectorT sh

-- Do all elements satisfy the predicate?
-- Uses vAll over the vector from toUnorderedVectorT.
{-# INLINE allT #-}
allT :: (Vector v, VecElem v a) => ShapeL -> (a -> Bool) -> T v a -> Bool
allT sh p = vAll p . toUnorderedVectorT sh

-- Update the array at the given multi-dimensional indices.
-- The array is linearised with toVectorT, each index is converted to a
-- position in that vector, and vUpdate applies the changes.  The result has
-- the strides computed for the shape and offset 0.
-- NOTE(review): no bounds checking of the indices is visible here.
{-# INLINE updateT #-}
updateT :: (Vector v, VecElem v a) => ShapeL -> T v a -> [([Int], a)] -> T v a
updateT sh t us = T ss 0 $ vUpdate (toVectorT sh t) $ map ix us
  where -- Strides for the shape (the head of getStridesT is dropped).
        _ : ss = getStridesT sh
        -- Linear position: the sum of index * stride over the dimensions.
        ix (is, a) = (sum $ zipWith (*) is ss, a)

-- Generate an array of the given shape by applying a function to every
-- multi-dimensional index.
-- The head of getStridesT is used as the vector length, the tail as the
-- strides of the result (offset 0).
{-# INLINE generateT #-}
generateT :: (Vector v, VecElem v a) => ShapeL -> ([Int] -> a) -> T v a
generateT sh f = T ss 0 $ vGenerate s g
  where s : ss = getStridesT sh
        -- Element at linear position i.
        g i = f (toIx ss i)
        -- Turn a linear position into a multi-dimensional index by
        -- successive division by the strides.
        toIx [] _ = []
        toIx (n:ns) i = q : toIx ns r where (q, r) = quotRem i n

-- A rank-1 array of length n holding x, f x, f (f x), ...
{-# INLINE iterateNT #-}
iterateNT :: (Vector v, VecElem v a) => Int -> (a -> a) -> a -> T v a
iterateNT n f x = fromListT [n] $ take n $ iterate f x

-- A rank-1 array of length n built from the enumeration [0 .. n-1].
{-# INLINE iotaT #-}
iotaT :: (Vector v, VecElem v a, Enum a, Num a) => Int -> T v a
iotaT n = fromListT [n] [0 .. fromIntegral n - 1]    -- TODO: should use V.enumFromTo instead

-------

-- | Permute the elements of a list, the first argument is indices into the original list.
-- The result has one element per index, in the order of the indices.
-- An index outside the list is an error, since the lookup uses '!!'.
permute :: [Int] -> [a] -> [a]
permute is xs = map (xs !!) is

-- | Like 'dropWhile' but at the end of the list.
-- Implemented by reversing, dropping from the front, and reversing back.
revDropWhile :: (a -> Bool) -> [a] -> [a]
revDropWhile p = reverse . dropWhile p . reverse

-- | Check if all elements of a list are equal.
-- The empty list and singleton lists give 'True'.
allSame :: (Eq a) => [a] -> Bool
allSame [] = True
allSame (x : xs) = all (x ==) xs

-- | Get the value of a type level Nat.
-- Use with explicit type application, i.e., @valueOf \@42@
-- The value is obtained with 'natVal' and converted with 'fromInteger'.
{-# INLINE valueOf #-}
valueOf :: forall n i . (KnownNat n, Num i) => i
valueOf = fromInteger $ natVal (Proxy :: Proxy n)