{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}

-- | Arrays with a fixed shape.
module NumHask.Array.Fixed
  ( -- $usage
    Array (..),

    -- * Conversion
    with,
    shape,
    toDynamic,

    -- * Operators
    reshape,
    transpose,
    diag,
    ident,
    singleton,
    selects,
    selectsExcept,
    folds,
    extracts,
    extractsExcept,
    joins,
    maps,
    concatenate,
    insert,
    append,
    reorder,
    expand,
    apply,
    contract,
    dot,
    mult,
    slice,
    squeeze,

    -- * Scalar

    --
    -- Scalar specialisations
    Scalar,
    fromScalar,
    toScalar,

    -- * Vector

    --
    -- Vector specialisations.
    Vector,

    -- * Matrix

    --
    -- Matrix specialisations.
    Matrix,
    col,
    row,
    safeCol,
    safeRow,
    mmult,
  )
where

import Data.Distributive (Distributive (..))
import Data.Functor.Rep
import Data.List ((!!))
import Data.Proxy
import qualified Data.Vector as V
import GHC.Exts (IsList (..))
import GHC.Show (Show (..))
import GHC.TypeLits
import qualified NumHask.Array.Dynamic as D
import NumHask.Array.Shape
import NumHask.Prelude as P hiding (toList)

-- $setup
--
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> :set -XFlexibleContexts
-- >>> :set -XRebindableSyntax
-- >>> import NumHask.Prelude
-- >>> import GHC.TypeLits (Nat)
-- >>> import Data.Proxy
-- >>> import Data.Functor.Rep
-- >>> let s = [1] :: Array ('[] :: [Nat]) Int -- scalar
-- >>> let v = [1,2,3] :: Array '[3] Int       -- vector
-- >>> let m = [0..11] :: Array '[3,4] Int     -- matrix
-- >>> let a = [1..24] :: Array '[2,3,4] Int

-- $usage
--
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> :set -XFlexibleContexts
-- >>> :set -XRebindableSyntax
-- >>> import NumHask.Prelude
-- >>> import NumHask.Array.Fixed
-- >>> import GHC.TypeLits (Nat)
-- >>> let s = [1] :: Array ('[] :: [Nat]) Int -- scalar
-- >>> let v = [1,2,3] :: Array '[3] Int       -- vector
-- >>> let m = [0..11] :: Array '[3,4] Int     -- matrix
-- >>> let a = [1..24] :: Array '[2,3,4] Int

-- | A multidimensional array with a type-level shape.
--
-- The shape @s@ is a type-level list of dimension sizes. Elements are held
-- in a single flat 'V.Vector' ('unArray') and addressed by per-dimension
-- indices through the 'Representable' instance.
--
-- >>> :set -XDataKinds
-- >>> [1..24] :: Array '[2,3,4] Int
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
--
-- >>> [1,2,3] :: Array '[2,2] Int
-- *** Exception: NumHaskException {errorMessage = "shape mismatch"}
newtype Array s a = Array {unArray :: V.Vector a}
  deriving (Eq, Ord, Functor, Foldable, Generic, Traversable)

-- | Shown by converting to the dynamic representation ('toDynamic') and
-- delegating to its 'Show' instance.
instance (HasShape s, Show a) => Show (Array s a) where
  show a = GHC.Show.show (toDynamic a)

-- | 'distribute' is obtained from the 'Representable' instance.
instance
  ( HasShape s
  ) =>
  Data.Distributive.Distributive (Array s)
  where
  distribute = distributeRep
  {-# INLINE distribute #-}

-- | An array is a function from a multi-dimensional index (one 'Int' per
-- dimension) to an element.
instance
  forall s.
  ( HasShape s
  ) =>
  Representable (Array s)
  where
  type Rep (Array s) = [Int]

  -- Generate every flat position, converting each back into a
  -- multi-dimensional index with 'shapen' before calling @f@.
  tabulate f =
    Array . V.generate (size s) $ (f . shapen s)
    where
      s :: [Int]
      s = shapeVal $ toShape @s
  {-# INLINE tabulate #-}

  -- Flatten the index against the shape and read the vector directly.
  -- NOTE(review): 'V.unsafeIndex' performs no bounds check, so an index
  -- whose flattened position falls outside the vector is undefined
  -- behaviour; callers are trusted to pass in-range indices.
  index (Array v) i = V.unsafeIndex v (flatten s i)
    where
      s :: [Int]
      s = shapeVal (toShape @s)
  {-# INLINE index #-}

-- * NumHask hierarchy

-- | Element-wise addition; 'zero' is the array filled with 'zero'.
instance
  ( Additive a,
    HasShape s
  ) =>
  Additive (Array s a)
  where
  (+) = liftR2 (+)

  zero = pureRep zero

-- | Element-wise negation.
instance
  ( Subtractive a,
    HasShape s
  ) =>
  Subtractive (Array s a)
  where
  negate = fmapRep negate

-- | Multiply every element by a scalar, from either side.
instance
  (HasShape s, Multiplicative a) =>
  MultiplicativeAction (Array s a) a
  where
  -- Scalar on the left: each element @x@ becomes @s * x@.
  (.*) s r = fmap (s *) r
  {-# INLINE (.*) #-}

  -- Scalar on the right: each element @x@ becomes @x * s@.
  (*.) r s = fmap (* s) r
  {-# INLINE (*.) #-}

-- | Add a scalar to every element, from either side.
instance
  (HasShape s, Additive a) =>
  AdditiveAction (Array s a) a
  where
  -- Scalar on the left: each element @x@ becomes @s + x@.
  (.+) s r = fmap (s +) r
  {-# INLINE (.+) #-}

  -- Scalar on the right: each element @x@ becomes @x + s@.
  (+.) r s = fmap (+ s) r
  {-# INLINE (+.) #-}

-- | Subtract element-wise with a scalar, from either side.
instance
  (HasShape s, Subtractive a) =>
  SubtractiveAction (Array s a) a
  where
  -- Scalar on the left: each element @x@ becomes @s - x@.
  (.-) s r = fmap (s -) r
  {-# INLINE (.-) #-}

  -- Scalar on the right: each element @x@ becomes @x - s@.
  -- A lambda is used because @(- s)@ parses as negation, not a section.
  (-.) r s = fmap (\x -> x - s) r
  {-# INLINE (-.) #-}

-- | Divide element-wise with a scalar, from either side.
instance
  (HasShape s, Divisive a) =>
  DivisiveAction (Array s a) a
  where
  -- Scalar on the left: each element @x@ becomes @s / x@.
  (./) s r = fmap (s /) r
  {-# INLINE (./) #-}

  -- Scalar on the right: each element @x@ becomes @x / s@.
  (/.) r s = fmap (/ s) r
  {-# INLINE (/.) #-}

-- | Element-wise join.
instance (HasShape s, JoinSemiLattice a) => JoinSemiLattice (Array s a) where
  (\/) = liftR2 (\/)

-- | Element-wise meet.
instance (HasShape s, MeetSemiLattice a) => MeetSemiLattice (Array s a) where
  (/\) = liftR2 (/\)

-- | 'epsilon' is the array filled with the element 'epsilon'; an array is
-- 'nearZero' when every one of its elements is.
instance (HasShape s, Subtractive a, Epsilon a) => Epsilon (Array s a) where
  epsilon = singleton epsilon

  nearZero (Array a) = all nearZero a

-- | Construct an array from a flat list of elements, or list its elements.
instance
  ( HasShape s
  ) =>
  IsList (Array s a)
  where
  type Item (Array s a) = a

  -- Throws 'NumHaskException' @"shape mismatch"@ unless the list length
  -- equals the size of the shape. A one-element list is also accepted when
  -- the shape is empty (a scalar).
  fromList l =
    bool
      (throw (NumHaskException "shape mismatch"))
      (Array $ V.fromList l)
      ((n == 1 && null ds) || (n == size ds))
    where
      -- The list length, computed once and shared by both checks.
      n :: Int
      n = length l
      ds :: [Int]
      ds = shapeVal (toShape @s)

  -- The elements in the order they are stored in the underlying vector.
  toList (Array v) = V.toList v

-- | Get the shape of an array as a value-level list of dimension sizes.
-- The array argument is not inspected; the shape comes from its type.
--
-- >>> shape a
-- [2,3,4]
shape :: forall a s. (HasShape s) => Array s a -> [Int]
shape _ = shapeVal $ toShape @s
{-# INLINE shape #-}

-- | Convert to a dynamic array, moving the shape from the type level to the
-- value level. The elements are copied via their flat list.
toDynamic :: (HasShape s) => Array s a -> D.Array a
toDynamic a = D.fromFlatList (shape a) (toList a)

-- | Use a dynamic array in a fixed context.
--
-- The dynamic array's elements are reused directly as an @Array s a@, so
-- they must number exactly @size@ of the fixed shape @s@, which is the count
-- 'tabulate' produces. Otherwise a 'NumHaskException' @"shape mismatch"@ is
-- thrown, as 'fromList' does. Previously the length was unchecked, and a
-- short vector led to out-of-bounds reads through the unchecked 'index'.
--
-- >>> import qualified NumHask.Array.Dynamic as D
-- >>> with (D.fromFlatList [2,3,4] [1..24]) (selects (Proxy :: Proxy '[0,1]) [1,1] :: Array '[2,3,4] Int -> Array '[4] Int)
-- [17, 18, 19, 20]
with ::
  forall a r s.
  (HasShape s) =>
  D.Array a ->
  (Array s a -> r) ->
  r
with (D.Array _ v) f =
  bool
    (throw (NumHaskException "shape mismatch"))
    (f (Array v))
    (V.length v == size (shapeVal (toShape @s)))

-- | Reshape an array (with the same number of elements).
--
-- Each index of the result is flattened against the target shape @s'@ and
-- then expanded against the source shape @s@ to find the source element.
--
-- >>> reshape a :: Array '[4,3,2] Int
-- [[[1, 2],
--   [3, 4],
--   [5, 6]],
--  [[7, 8],
--   [9, 10],
--   [11, 12]],
--  [[13, 14],
--   [15, 16],
--   [17, 18]],
--  [[19, 20],
--   [21, 22],
--   [23, 24]]]
reshape ::
  forall a s s'.
  ( Size s ~ Size s',
    HasShape s,
    HasShape s'
  ) =>
  Array s a ->
  Array s' a
reshape a = tabulate (index a . shapen s . flatten s')
  where
    s :: [Int]
    s = shapeVal (toShape @s)
    s' :: [Int]
    s' = shapeVal (toShape @s')

-- | Reverse the order of the dimensions: element A/ijk/ of the input becomes
-- element A/kji/ of the result.
--
-- >>> index (transpose a) [1,0,0] == index a [0,0,1]
-- True
transpose :: forall a s. (HasShape s, HasShape (Reverse s)) => Array s a -> Array (Reverse s) a
transpose a = tabulate (index a . reverse)

-- | The identity array: 'one' at every index whose components are all
-- equal, 'zero' everywhere else. The shape need not be square.
--
-- >>> ident :: Array '[3,2] Int
-- [[1, 0],
--  [0, 1],
--  [0, 0]]
ident :: forall a s. (HasShape s, Additive a, Multiplicative a) => Array s a
ident = tabulate (bool zero one . isDiag)
  where
    -- True when every component equals the first (vacuously true for the
    -- empty and singleton index).
    isDiag :: [Int] -> Bool
    isDiag [] = True
    isDiag (x : xs) = all (== x) xs

-- | Extract the diagonal of an array: element @i@ of the result is the
-- source element whose index components all equal @i@.
--
-- >>> diag (ident :: Array '[3,2] Int)
-- [1, 1]
diag ::
  forall a s.
  ( HasShape s,
    HasShape '[Minimum s]
  ) =>
  Array s a ->
  Array '[Minimum s] a
diag a = tabulate go
  where
    -- The result is rank 1, so @go@ is expected to receive a one-element
    -- index; the empty case is a guard that throws.
    go :: [Int] -> a
    go [] = throw (NumHaskException "Rank Underflow")
    go (s' : _) = index a (replicate (length ds) s')
    ds :: [Int]
    ds = shapeVal (toShape @s)

-- | Create an array with every element set to the given value.
--
-- >>> singleton one :: Array '[3,2] Int
-- [[1, 1],
--  [1, 1],
--  [1, 1]]
singleton :: (HasShape s) => a -> Array s a
singleton a = tabulate (const a)

-- | Fix the dimensions listed in @ds@ to the supplied positions.
--
-- The result is the sub-array spanned by the remaining dimensions.
--
-- >>> let s = selects (Proxy :: Proxy '[0,1]) [1,1] a
-- >>> :t s
-- s :: Array '[4] Int
--
-- >>> s
-- [17, 18, 19, 20]
selects ::
  forall ds s s' a.
  ( HasShape s,
    HasShape ds,
    HasShape s',
    s' ~ DropIndexes s ds
  ) =>
  Proxy ds ->
  [Int] ->
  Array s a ->
  Array s' a
selects _ i a = tabulate (\idx -> index a (addIndexes idx dims i))
  where
    -- the selected dimensions, as a value-level list
    dims :: [Int]
    dims = shapeVal (toShape @ds)

-- | Keep only the dimensions listed in @ds@, fixing every other dimension to
-- the supplied positions (in order).
--
-- >>> let s = selectsExcept (Proxy :: Proxy '[2]) [1,1] a
-- >>> :t s
-- s :: Array '[4] Int
--
-- >>> s
-- [17, 18, 19, 20]
selectsExcept ::
  forall ds s s' a.
  ( HasShape s,
    HasShape ds,
    HasShape s',
    s' ~ TakeIndexes s ds
  ) =>
  Proxy ds ->
  [Int] ->
  Array s a ->
  Array s' a
selectsExcept _ i a = tabulate (index a . addIndexes i dims)
  where
    -- the retained dimensions, as a value-level list
    dims :: [Int]
    dims = shapeVal (toShape @ds)

-- | Reduce an array with a fold, keeping the dimensions listed in @ds@.
--
-- Each element of the result is the fold applied to the sub-array obtained
-- by fixing the @ds@ dimensions at that element's position.
--
-- >>> folds sum (Proxy :: Proxy '[1]) a
-- [68, 100, 132]
folds ::
  forall ds st si so a b.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds
  ) =>
  (Array si a -> b) ->
  Proxy ds ->
  Array st a ->
  Array so b
folds f d a = tabulate (\idx -> f (selects d idx a))

-- | Split an array into an outer array indexed by the dimensions in @ds@.
--
-- Each element of the outer array is the sub-array over the remaining
-- dimensions.
--
-- >>> let e = extracts (Proxy :: Proxy '[1,2]) a
-- >>> :t e
-- e :: Array '[3, 4] (Array '[2] Int)
extracts ::
  forall ds st si so a.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds
  ) =>
  Proxy ds ->
  Array st a ->
  Array so (Array si a)
extracts d a = tabulate (\idx -> selects d idx a)

-- | Split an array into an outer array indexed by every dimension /not/ in
-- @ds@.
--
-- Each element of the outer array is the sub-array over the @ds@
-- dimensions.
--
-- >>> let e = extractsExcept (Proxy :: Proxy '[1,2]) a
-- >>> :t e
-- e :: Array '[2] (Array '[3, 4] Int)
extractsExcept ::
  forall ds st si so a.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    so ~ DropIndexes st ds,
    si ~ TakeIndexes st ds
  ) =>
  Proxy ds ->
  Array st a ->
  Array so (Array si a)
extractsExcept d a = tabulate (\idx -> selectsExcept d idx a)

-- | Merge an array of arrays back into a single array.
--
-- The outer layer supplies the dimensions listed in @ds@; the inner arrays
-- supply the rest. This undoes 'extracts', as shown below.
--
-- >>> let e = extracts (Proxy :: Proxy '[1,0]) a
--
-- >>> :t e
-- e :: Array '[3, 2] (Array '[4] Int)
--
-- >>> let j = joins (Proxy :: Proxy '[1,0]) e
--
-- >>> :t j
-- j :: Array '[2, 3, 4] Int
--
-- >>> a == j
-- True
joins ::
  forall ds si st so a.
  ( HasShape st,
    HasShape ds,
    st ~ AddIndexes si ds so,
    HasShape si,
    HasShape so
  ) =>
  Proxy ds ->
  Array so (Array si a) ->
  Array st a
joins _ a = tabulate $ \idx ->
  let inner = index a (takeIndexes idx dims)
   in index inner (dropIndexes idx dims)
  where
    -- positions of the outer-layer dimensions in the joined array
    dims :: [Int]
    dims = shapeVal (toShape @ds)

-- | Apply an array-to-array function to every sub-array obtained by fixing
-- the dimensions in @ds@, then reassemble the results.
--
-- >>> :t maps (transpose) (Proxy :: Proxy '[1]) a
-- maps (transpose) (Proxy :: Proxy '[1]) a :: Array '[4, 3, 2] Int
maps ::
  forall ds st st' si si' so a b.
  ( HasShape st,
    HasShape st',
    HasShape ds,
    HasShape si,
    HasShape si',
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds,
    st' ~ AddIndexes si' ds so,
    st ~ AddIndexes si ds so
  ) =>
  (Array si a -> Array si' b) ->
  Proxy ds ->
  Array st a ->
  Array st' b
maps f d = joins d . fmapRep f . extracts d

-- | Join two arrays end to end along dimension @d@.
--
-- The allowed shapes are constrained by 'CheckConcatenate'. Positions below
-- the size of @s0@ along @d@ come from the first array; later positions come
-- from the second, shifted down by that size.
--
-- >>> :t concatenate (Proxy :: Proxy 1) a a
-- concatenate (Proxy :: Proxy 1) a a :: Array '[2, 6, 4] Int
concatenate ::
  forall a s0 s1 d s.
  ( CheckConcatenate d s0 s1 s,
    Concatenate d s0 s1 ~ s,
    HasShape s0,
    HasShape s1,
    HasShape s,
    KnownNat d
  ) =>
  Proxy d ->
  Array s0 a ->
  Array s1 a ->
  Array s a
concatenate _ s0 s1 = tabulate pick
  where
    pick :: [Int] -> a
    pick idx
      | k < n0 = index s0 idx
      | otherwise = index s1 (addIndex (dropIndex idx d) d (k - n0))
      where
        k = idx !! d
    -- size of the first array along the concatenation dimension
    n0 :: Int
    n0 = shapeVal (toShape @s0) !! d
    d :: Int
    d = fromIntegral (natVal (Proxy :: Proxy d))

-- | Insert the slice @b@ into @a@ along dimension @d@, at position @i@.
--
-- Positions before @i@ come from @a@ unchanged, position @i@ comes from
-- @b@, and later positions come from @a@ shifted down by one.
--
-- >>> insert (Proxy :: Proxy 2) (Proxy :: Proxy 0) a ([100..105])
-- [[[100, 1, 2, 3, 4],
--   [101, 5, 6, 7, 8],
--   [102, 9, 10, 11, 12]],
--  [[103, 13, 14, 15, 16],
--   [104, 17, 18, 19, 20],
--   [105, 21, 22, 23, 24]]]
insert ::
  forall a s s' d i.
  ( DropIndex s d ~ s',
    CheckInsert d i s,
    KnownNat i,
    KnownNat d,
    HasShape s,
    HasShape s',
    HasShape (Insert d s)
  ) =>
  Proxy d ->
  Proxy i ->
  Array s a ->
  Array s' a ->
  Array (Insert d s) a
insert _ _ a b = tabulate pick
  where
    pick :: [Int] -> a
    pick idx
      | k < i = index a idx
      | k == i = index b (dropIndex idx d)
      | otherwise = index a (decAt d idx)
      where
        k = idx !! d
    d :: Int
    d = fromIntegral (natVal (Proxy :: Proxy d))
    i :: Int
    i = fromIntegral (natVal (Proxy :: Proxy i))

-- | Insert along a dimension at the end.
--
-- The slice @b@ is placed after every existing slice of @a@ along
-- dimension @d@. The result contains the slices of @a@ in their original
-- order, followed by @b@.
--
-- >>>  :t append (Proxy :: Proxy 0) a
-- append (Proxy :: Proxy 0) a
--   :: Array '[3, 4] Int -> Array '[3, 3, 4] Int
--
-- >>> append (Proxy :: Proxy 2) a ([100..105])
-- [[[1, 2, 3, 4, 100],
--   [5, 6, 7, 8, 101],
--   [9, 10, 11, 12, 102]],
--  [[13, 14, 15, 16, 103],
--   [17, 18, 19, 20, 104],
--   [21, 22, 23, 24, 105]]]
append ::
  forall a d s s'.
  ( DropIndex s d ~ s',
    CheckInsert d (Dimension s d - 1) s,
    KnownNat (Dimension s d - 1),
    KnownNat d,
    HasShape s,
    HasShape s',
    HasShape (Insert d s)
  ) =>
  Proxy d ->
  Array s a ->
  Array s' a ->
  Array (Insert d s) a
append _ a b = tabulate go
  where
    -- Positions 0 .. n-1 along dimension d are copied from @a@ unchanged.
    -- The new final position n (one past a's last index) is taken from @b@.
    --
    -- This is written directly rather than as
    -- @insert d (Proxy :: Proxy (Dimension s d - 1))@. That form places @b@
    -- at index n-1 and shifts a's last slice up, which leaves @b@ second to
    -- last instead of last.
    go :: [Int] -> a
    go s
      | s !! d == n = index b (dropIndex s d)
      | otherwise = index a s
    -- size of dimension d in the input array @a@
    n :: Int
    n = shape a !! d
    d :: Int
    d = fromIntegral $ natVal @d Proxy

-- | Permute the dimensions of an array according to @ds@.
--
-- As the example shows, @'[2,0,1]@ turns shape @[2,3,4]@ into @[4,2,3]@:
-- dimension @k@ of the result is dimension @ds !! k@ of the input.
--
-- >>> let r = reorder (Proxy :: Proxy '[2,0,1]) a
-- >>> :t r
-- r :: Array '[4, 2, 3] Int
reorder ::
  forall a ds s.
  ( HasShape ds,
    HasShape s,
    HasShape (Reorder s ds),
    CheckReorder ds s
  ) =>
  Proxy ds ->
  Array s a ->
  Array (Reorder s ds) a
reorder _ a = tabulate (index a . addIndexes [] perm)
  where
    -- the permutation, as a value-level list
    perm :: [Int]
    perm = shapeVal (toShape @ds)

-- | Form the outer product of two arrays using the supplied binary function.
--
-- The result shape is the first shape followed by the second. Each element
-- combines one element of each input: the leading indices address the first
-- array and the trailing indices address the second.
--
-- For context, if the function is multiply, and the arrays are tensors,
-- then this can be interpreted as a tensor product.
--
-- https://en.wikipedia.org/wiki/Tensor_product
--
-- The concept of a tensor product is a dense crossroad, and a complete treatment is elsewhere.  To quote:
--
-- ... the tensor product can be extended to other categories of mathematical objects in addition to vector spaces, such as to matrices, tensors, algebras, topological vector spaces, and modules. In each such case the tensor product is characterized by a similar universal property: it is the freest bilinear operation. The general concept of a "tensor product" is captured by monoidal categories; that is, the class of all things that have a tensor product is a monoidal category.
--
-- >>> expand (*) v v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
expand ::
  forall s s' a b c.
  ( HasShape s,
    HasShape s',
    HasShape ((++) s s')
  ) =>
  (a -> b -> c) ->
  Array s a ->
  Array s' b ->
  Array ((++) s s') c
expand f a b = tabulate go
  where
    go :: [Int] -> c
    go idx = f (index a front) (index b back)
      where
        front = take r idx
        back = drop r idx
    -- rank of the first array: how many leading indices belong to it
    r :: Int
    r = rank (shape a)

-- | Apply every function in one array to every value in another.
--
-- The result shape is the function array's shape followed by the value
-- array's shape. This is in the spirit of the applicative functor
-- operation (<*>).
--
-- > expand f a b == apply (fmap f a) b
--
-- >>> apply ((*) <$> v) v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
--
-- Arrays can't be applicative functors in haskell because the changes in shape are reflected in the types.
--
-- > :t apply
-- apply
--   :: (HasShape s, HasShape s', HasShape (s ++ s')) =>
--      Array s (a -> b) -> Array s' a -> Array (s ++ s') b
-- > :t (<*>)
-- (<*>) :: Applicative f => f (a -> b) -> f a -> f b
--
-- >>> let b = [1..6] :: Array '[2,3] Int
-- >>> contract sum (Proxy :: Proxy '[1,2]) (apply (fmap (*) b) (transpose b))
-- [[14, 32],
--  [32, 77]]
apply ::
  forall s s' a b.
  ( HasShape s,
    HasShape s',
    HasShape ((++) s s')
  ) =>
  Array s (a -> b) ->
  Array s' a ->
  Array ((++) s s') b
apply f a = tabulate go
  where
    go :: [Int] -> b
    go idx = fn (index a (drop r idx))
      where
        fn = index f (take r idx)
    -- rank of the function array: how many leading indices belong to it
    r :: Int
    r = rank (shape f)

-- | Contract an array by folding the diagonal of the dimensions listed in
-- @ds@.
--
-- For every position in the remaining dimensions, this takes the sub-array
-- over @ds@, extracts its diagonal, and reduces that diagonal with the
-- supplied function. This generalises a tensor contraction in two ways:
-- more than two dimensions may be contracted, and the reduction need not
-- be a sum of products.
--
-- >>> let b = [1..6] :: Array '[2,3] Int
-- >>> contract sum (Proxy :: Proxy '[1,2]) (expand (*) b (transpose b))
-- [[14, 32],
--  [32, 77]]
contract ::
  forall a b s ss s' ds.
  ( KnownNat (Minimum (TakeIndexes s ds)),
    HasShape (TakeIndexes s ds),
    HasShape s,
    HasShape ds,
    HasShape ss,
    HasShape s',
    s' ~ DropIndexes s ds,
    ss ~ '[Minimum (TakeIndexes s ds)]
  ) =>
  (Array ss a -> b) ->
  Proxy ds ->
  Array s a ->
  Array s' b
contract f xs = fmap (f . diag) . extractsExcept xs

-- | A generalised dot product.
--
-- The two arrays are combined with an outer product ('expand' with the
-- binary function). The result is then contracted ('contract' with the
-- folding function) along the last dimension of the first array and the
-- first dimension of the second.
--
-- matrix multiplication
--
-- >>> let b = [1..6] :: Array '[2,3] Int
-- >>> dot sum (*) b (transpose b)
-- [[14, 32],
--  [32, 77]]
--
-- inner product
--
-- >>> let v = [1..3] :: Array '[3] Int
-- >>> :t dot sum (*) v v
-- dot sum (*) v v :: Array '[] Int
--
-- >>> dot sum (*) v v
-- 14
--
-- matrix-vector multiplication
-- (Note how the vector doesn't need to be converted to a row or column vector)
--
-- >>> dot sum (*) v b
-- [9, 12, 15]
--
-- >>> dot sum (*) b v
-- [14, 32]
--
-- dot allows operation on mis-shaped matrices:
--
-- >>> let m23 = [1..6] :: Array '[2,3] Int
-- >>> let m12 = [1,2] :: Array '[1,2] Int
-- >>> shape $ dot sum (*) m23 m12
-- [2,2]
--
-- the algorithm ignores excess positions within the contracting dimension(s):
--
-- m23 shape: 2 3
--
-- m12 shape: 1 2
--
-- res shape: 2 2
--
-- FIXME: work out whether this is a feature or a bug...
--
-- find instances of a vector in a matrix
--
-- >>> let cs = fromList ("abacbaab" :: [Char]) :: Array '[4,2] Char
-- >>> let v = fromList ("ab" :: [Char]) :: Vector 2 Char
-- >>> dot (all id) (==) cs v
-- [True, False, False, True]
dot ::
  forall a b c d sa sb s' ss se.
  ( HasShape sa,
    HasShape sb,
    HasShape (sa ++ sb),
    se ~ TakeIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape se,
    KnownNat (Minimum se),
    KnownNat (Rank sa - 1),
    KnownNat (Rank sa),
    ss ~ '[Minimum se],
    HasShape ss,
    s' ~ DropIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape s'
  ) =>
  (Array ss c -> d) ->
  (a -> b -> c) ->
  Array sa a ->
  Array sb b ->
  Array s' d
dot f g a b = contract f inner (expand g a b)
  where
    -- the last dimension of @sa@ and the first dimension of @sb@ within the
    -- expanded shape
    inner :: Proxy '[Rank sa - 1, Rank sa]
    inner = Proxy

-- | Multiply two arrays as a 'dot' product that sums products.
--
-- matrix multiplication
--
-- >>> let b = [1..6] :: Array '[2,3] Int
-- >>> mult b (transpose b)
-- [[14, 32],
--  [32, 77]]
--
-- inner product
--
-- >>> let v = [1..3] :: Array '[3] Int
-- >>> :t mult v v
-- mult v v :: Array '[] Int
--
-- >>> mult v v
-- 14
--
-- matrix-vector multiplication
--
-- >>> mult v b
-- [9, 12, 15]
--
-- >>> mult b v
-- [14, 32]
mult ::
  forall a sa sb s' ss se.
  ( Additive a,
    Multiplicative a,
    HasShape sa,
    HasShape sb,
    HasShape (sa ++ sb),
    se ~ TakeIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape se,
    KnownNat (Minimum se),
    KnownNat (Rank sa - 1),
    KnownNat (Rank sa),
    ss ~ '[Minimum se],
    HasShape ss,
    s' ~ DropIndexes (sa ++ sb) '[Rank sa - 1, Rank sa],
    HasShape s'
  ) =>
  Array sa a ->
  Array sb a ->
  Array s' a
mult x y = dot sum (*) x y

-- | Select a list of positions in every dimension.
--
-- The k-th inner list of @pss@ gives the positions kept along dimension k.
-- Result dimension k therefore has one entry per listed position.
--
-- >>> let s = slice (Proxy :: Proxy '[[0,1],[0,2],[1,2]]) a
-- >>> :t s
-- s :: Array '[2, 2, 2] Int
--
-- >>> s
-- [[[2, 3],
--   [10, 11]],
--  [[14, 15],
--   [22, 23]]]
--
-- >>> let s = squeeze $ slice (Proxy :: Proxy '[ '[0], '[0], '[0]]) a
-- >>> :t s
-- s :: Array '[] Int
--
-- >>> s
-- 1
slice ::
  forall (pss :: [[Nat]]) s s' a.
  ( HasShape s,
    HasShape s',
    KnownNatss pss,
    KnownNat (Rank pss),
    s' ~ Ranks pss
  ) =>
  Proxy pss ->
  Array s a ->
  Array s' a
slice p a = tabulate (index a . zipWith (!!) picks)
  where
    -- value-level copy of the per-dimension position lists
    picks :: [[Int]]
    picks = natValss p

-- | Remove dimensions of size one.
--
-- Only the type-level shape changes; the underlying vector is reused as-is.
--
-- >>> let a = [1..24] :: Array '[2,1,3,4,1] Int
-- >>> a
-- [[[[[1],
--     [2],
--     [3],
--     [4]],
--    [[5],
--     [6],
--     [7],
--     [8]],
--    [[9],
--     [10],
--     [11],
--     [12]]]],
--  [[[[13],
--     [14],
--     [15],
--     [16]],
--    [[17],
--     [18],
--     [19],
--     [20]],
--    [[21],
--     [22],
--     [23],
--     [24]]]]]
-- >>> squeeze a
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
--
-- >>> squeeze ([1] :: Array '[1,1] Double)
-- 1.0
squeeze ::
  forall s t a.
  (t ~ Squeeze s) =>
  Array s a ->
  Array t a
squeeze arr = case arr of
  Array flat -> Array flat

-- $scalar
-- Scalar specialisations

-- | <https://en.wikipedia.org/wiki/Scalar_(mathematics) Wiki Scalar>
--
-- An Array '[] a, despite being a scalar, is nevertheless a one-element
-- vector under the hood. Unifying the two representations is unexplored.
type Scalar a = Array ('[] :: [Nat]) a

-- | Read the single element of a rank-0 array by indexing it at the empty
-- index.
--
-- Unwrapping scalars is probably a performance bottleneck.
--
-- >>> let s = [3] :: Array ('[] :: [Nat]) Int
-- >>> fromScalar s
-- 3
fromScalar :: (HasShape ('[] :: [Nat])) => Array ('[] :: [Nat]) a -> a
fromScalar = flip index ([] :: [Int])

-- | Wrap a value as a rank-0 array built from a one-element list.
--
-- >>> :t toScalar 2
-- toScalar 2 :: FromInteger a => Array '[] a
toScalar :: (HasShape ('[] :: [Nat])) => a -> Array ('[] :: [Nat]) a
toScalar = fromList . (: [])

-- | A rank-1 array of length @n@.
--
-- <https://en.wikipedia.org/wiki/Vector_(mathematics_and_physics) Wiki Vector>
type Vector n a = Array '[n] a

-- | A rank-2 array with @rows@ rows and @cols@ columns.
--
-- <https://en.wikipedia.org/wiki/Matrix_(mathematics) Wiki Matrix>
type Matrix rows cols a = Array '[rows, cols] a

-- | Square matrices are 'Multiplicative': @(*)@ is 'mmult' and 'one' is
-- 'ident'.
instance
  ( Multiplicative a,
    P.Distributive a,
    Subtractive a,
    KnownNat m,
    HasShape '[m, m]
  ) =>
  Multiplicative (Matrix m m a)
  where
  -- Matrix product; the instance constraints provide what 'mmult' requires.
  lhs * rhs = mmult lhs rhs

  -- 'ident' serves as the multiplicative unit.
  one = ident

-- | Extract row @i@ of a matrix as a vector of length @n@.
--
-- The index is not checked at the type level; an out-of-range @i@ is passed
-- straight through to 'V.slice'. See 'safeRow' for a type-checked variant.
--
-- >>> row 1 m
-- [4, 5, 6, 7]
row :: forall m n a. (KnownNat m, KnownNat n, HasShape '[m, n]) => Int -> Matrix m n a -> Vector n a
row i (Array elems) = Array (V.slice offset width elems)
  where
    -- Row length, i.e. the number of columns.
    width :: Int
    width = fromIntegral (natVal (Proxy :: Proxy n))
    -- Rows are stored contiguously, so row i starts at i * width.
    offset :: Int
    offset = i * width

-- | Extract row @j@ of a matrix, where @j@ is a type-level natural that is
-- checked against the row count @m@ at compile time.
--
-- >>> safeRow (Proxy :: Proxy 1) m
-- [4, 5, 6, 7]
--
-- >>> safeRow (Proxy :: Proxy 3) m
-- ...
-- ... index outside range
-- ...
safeRow :: forall m n a j. ('True ~ CheckIndex j m, KnownNat j, KnownNat m, KnownNat n, HasShape '[m, n]) => Proxy j -> Matrix m n a -> Vector n a
safeRow rowProxy (Array elems) = Array (V.slice (rowIx * width) width elems)
  where
    -- Row length, i.e. the number of columns.
    width :: Int
    width = fromIntegral (natVal (Proxy :: Proxy n))
    -- Term-level copy of the type-level row index.
    rowIx :: Int
    rowIx = fromIntegral (natVal rowProxy)

-- | Extract column @i@ of a matrix as a vector of length @m@ (one element
-- per row).
--
-- The index is checked at runtime: an @i@ outside @[0, n)@ throws a
-- 'NumHaskException' instead of reading out of bounds or wrapping into a
-- neighbouring row. See 'safeCol' for a type-checked variant.
--
-- >>> col 1 m
-- [1, 5, 9]
col :: forall m n a. (KnownNat m, KnownNat n, HasShape '[m, n]) => Int -> Matrix m n a -> Vector m a
col i (Array a)
  -- Guard before using unsafeIndex: for 0 <= i < n and 0 <= x < m the
  -- offset i + x * n stays below m * n.
  | i < 0 || i >= n = throw (NumHaskException "col: index outside range")
  | otherwise = Array $ V.generate m (\x -> V.unsafeIndex a (i + x * n))
  where
    -- Number of rows (length of the resulting column).
    m :: Int
    m = fromIntegral $ natVal @m Proxy
    -- Number of columns (stride between consecutive column elements).
    n :: Int
    n = fromIntegral $ natVal @n Proxy

-- | Extract column @j@ of a matrix as a vector of length @m@ (one element per
-- row), where @j@ is a type-level natural checked against the column count
-- @n@ at compile time.
--
-- >>> safeCol (Proxy :: Proxy 1) m
-- [1, 5, 9]
--
-- >>> safeCol (Proxy :: Proxy 4) m
-- ...
-- ... index outside range
-- ...
safeCol :: forall m n a j. ('True ~ CheckIndex j n, KnownNat j, KnownNat m, KnownNat n, HasShape '[m, n]) => Proxy j -> Matrix m n a -> Vector m a
safeCol _j (Array a) = Array $ V.generate m (\x -> V.unsafeIndex a (j + x * n))
  where
    -- Number of rows (length of the resulting column).
    m :: Int
    m = fromIntegral $ natVal @m Proxy
    -- Number of columns (stride between consecutive column elements).
    n :: Int
    n = fromIntegral $ natVal @n Proxy
    -- Term-level copy of the column index; j < n is guaranteed by the
    -- CheckIndex constraint, which is what keeps unsafeIndex in bounds
    -- (assuming the underlying vector holds m * n elements).
    j :: Int
    j = fromIntegral $ natVal @j Proxy

-- | Matrix multiplication.
--
-- This is dot sum (*) specialised to matrices: element @(r, c)@ of the result
-- is the sum of pairwise products of row @r@ of the left operand with column
-- @c@ of the right operand.
--
-- >>> let a = [1, 2, 3, 4] :: Array '[2, 2] Int
-- >>> let b = [5, 6, 7, 8] :: Array '[2, 2] Int
-- >>> a
-- [[1, 2],
--  [3, 4]]
--
-- >>> b
-- [[5, 6],
--  [7, 8]]
--
-- >>> mmult a b
-- [[19, 22],
--  [43, 50]]
mmult ::
  forall m n k a.
  ( KnownNat k,
    KnownNat m,
    KnownNat n,
    HasShape [m, n],
    Ring a
  ) =>
  Array [m, k] a ->
  Array [k, n] a ->
  Array [m, n] a
mmult (Array xs) (Array ys) = tabulate entry
  where
    -- Value of the result at a given index list.
    entry :: [Int] -> a
    entry idx = case idx of
      (r : c : _) -> sum (V.zipWith (*) (leftRow r) (rightCol c))
      -- Presumably unreachable, since tabulate over a two-dimensional shape
      -- supplies two indices; kept so the match is total.
      _ -> throw (NumHaskException "Needs two dimensions")

    -- Row r of the left operand: a contiguous run of `inner` elements.
    leftRow :: Int -> V.Vector a
    leftRow r = V.slice (r * inner) inner xs

    -- Column c of the right operand, gathered with a stride of `cols`.
    rightCol :: Int -> V.Vector a
    rightCol c = V.generate inner (\t -> ys V.! (c + t * cols))

    -- Number of columns of the right operand (and of the result).
    cols :: Int
    cols = fromIntegral (natVal (Proxy :: Proxy n))

    -- Shared inner dimension.
    inner :: Int
    inner = fromIntegral (natVal (Proxy :: Proxy k))
{-# INLINE mmult #-}