{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}

-- | Arrays with a dynamic shape.
module NumHask.Array.Dynamic
  ( -- $usage
    Array (..),

    -- * Conversion
    fromFlatList,
    toFlatList,

    -- * representable replacements
    index,
    tabulate,

    -- * Operators
    reshape,
    transpose,
    diag,
    ident,
    singleton,
    selects,
    selectsExcept,
    folds,
    extracts,
    extractsExcept,
    joins,
    maps,
    concatenate,
    insert,
    append,
    reorder,
    expand,
    apply,
    contract,
    dot,
    mult,
    slice,
    squeeze,

    -- * Scalar

    --
    -- Scalar specialisations
    fromScalar,
    toScalar,

    -- * Matrix

    --
    -- Matrix specialisations.
    col,
    row,
    mmult,
  )
where

import Data.List (intercalate)
import qualified Data.Vector as V
import GHC.Show (Show (..))
import NumHask.Array.Shape
import NumHask.Prelude as P hiding (product)

-- $setup
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> :set -XFlexibleContexts
-- >>> :set -XRebindableSyntax
-- >>> import NumHask.Prelude
-- >>> import NumHask.Array.Dynamic
-- >>> import NumHask.Array.Shape
-- >>> let s = fromFlatList [] [1] :: Array Int
-- >>> let a = fromFlatList [2,3,4] [1..24] :: Array Int
-- >>> let v = fromFlatList [3] [1,2,3] :: Array Int
-- >>> let m = fromFlatList [3,4] [0..11] :: Array Int

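-- $usage
--
-- The examples in this module assume the following extensions, imports and
-- example arrays:
--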
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> :set -XFlexibleContexts
-- >>> :set -XRebindableSyntax
-- >>> import NumHask.Prelude
-- >>> import NumHask.Array.Dynamic
-- >>> import NumHask.Array.Shape
-- >>> let s = fromFlatList [] [1] :: Array Int
-- >>> let a = fromFlatList [2,3,4] [1..24] :: Array Int
-- >>> let v = fromFlatList [3] [1,2,3] :: Array Int
-- >>> let m = fromFlatList [3,4] [0..11] :: Array Int

-- | A multidimensional array whose shape is known only at runtime.
--
-- The elements live in a single flat vector ('unArray'), laid out so that the
-- last dimension varies fastest; 'shape' gives the extent of every dimension,
-- outermost first.
--
-- >>> let a = fromFlatList [2,3,4] [1..24] :: Array Int
-- >>> a
-- [[[1, 2, 3, 4],
--   [5, 6, 7, 8],
--   [9, 10, 11, 12]],
--  [[13, 14, 15, 16],
--   [17, 18, 19, 20],
--   [21, 22, 23, 24]]]
data Array a = Array
  { -- | extent of each dimension, outermost first
    shape :: [Int],
    -- | the elements, stored flat
    unArray :: V.Vector a
  }
  deriving (Eq, Ord, Generic)

-- | Apply a function to every element; the shape is carried over unchanged.
instance Functor Array where
  fmap f arr = arr {unArray = V.map f (unArray arr)}

-- | Folds walk the elements in their flat storage order.
instance Foldable Array where
  foldr step z = V.foldr step z . unArray

-- | Traverse the elements in flat storage order, rebuilding an array of the
-- same shape from the results.
instance Traversable Array where
  traverse f (Array ds xs) = fmap (fromFlatList ds) (traverse f (V.toList xs))

-- | Render as nested brackets, one level per dimension.
--
-- A scalar shows as its single element, a vector as a comma-separated row, and
-- anything of higher rank splits along its outermost dimension, putting each
-- sub-array on its own line indented to line up under the opening brackets.
instance (Show a) => Show (Array a) where
  show arr = render (length (shape arr)) arr
    where
      -- depth is the rank of the whole array; it fixes the indentation of
      -- nested rows regardless of how deep the recursion currently is
      render depth sub =
        case length (shape sub) of
          0 -> GHC.Show.show (V.head (unArray sub))
          1 -> "[" ++ intercalate ", " (fmap GHC.Show.show (V.toList (unArray sub))) ++ "]"
          r ->
            let sep = ",\n" ++ replicate (depth - r + 1) ' '
                rows = fmap (render depth) (toFlatList (extracts [0] sub))
             in "[" ++ intercalate sep rows ++ "]"

-- * conversions

-- | Build an array from a shape and a flat list of elements, last dimension
-- varying fastest.
--
-- Only the first @size ds@ elements are consumed, so an infinite list is fine.
-- A list with fewer than @size ds@ elements is rejected with a
-- 'NumHaskException'. 'index' reads the store without bounds checks, so a
-- store shorter than its shape would otherwise lead to out-of-bounds reads.
--
-- >>> fromFlatList [2,3,4] [1..24] == a
-- True
fromFlatList :: [Int] -> [a] -> Array a
fromFlatList ds l
  | V.length v == n = Array ds v
  | otherwise = throw (NumHaskException "fromFlatList: not enough elements for the shape")
  where
    n = size ds
    v = V.fromList (take n l)

-- | All elements as a list, in flat storage order (the shape is discarded).
--
-- >>> toFlatList a == [1..24]
-- True
toFlatList :: Array a -> [a]
toFlatList = V.toList . unArray

-- | Look up the element at a multi-dimensional index.
--
-- The index is flattened against the array's shape and the store is read with
-- 'V.unsafeIndex', so nothing is bounds-checked: an index outside the shape
-- may read outside the store.
--
-- >>> index a [1,2,3]
-- 24
index :: () => Array a -> [Int] -> a
index arr i = V.unsafeIndex (unArray arr) (flatten (shape arr) i)

-- | Build an array of the given shape by calling a function on the
-- multi-dimensional index of every position.
--
-- >>> tabulate [2,3,4] ((1+) . flatten [2,3,4]) == a
-- True
tabulate :: () => [Int] -> ([Int] -> a) -> Array a
tabulate ds f = Array ds (V.generate (size ds) (\k -> f (shapen ds k)))

-- | Give the elements a new shape, keeping their flat order.
--
-- The new shape is expected to hold the same number of elements as the old
-- one; this is not checked.
--
-- >>> reshape [4,3,2] a
-- [[[1, 2],
--   [3, 4],
--   [5, 6]],
--  [[7, 8],
--   [9, 10],
--   [11, 12]],
--  [[13, 14],
--   [15, 16],
--   [17, 18]],
--  [[19, 20],
--   [21, 22],
--   [23, 24]]]
reshape ::
  [Int] ->
  Array a ->
  Array a
reshape s a = tabulate s fromNew
  where
    -- new index -> flat position -> old index
    fromNew = index a . shapen (shape a) . flatten s

-- | Reverse the order of the dimensions, so element A/ijk/ moves to A/kji/.
--
-- >>> index (transpose a) [1,0,0] == index a [0,0,1]
-- True
transpose :: Array a -> Array a
transpose a = tabulate (reverse (shape a)) (\i -> index a (reverse i))

-- | An array that is 'one' wherever all index components are equal and
-- 'zero' everywhere else.
--
-- >>> ident [3,2]
-- [[1, 0],
--  [0, 1],
--  [0, 0]]
ident :: (Additive a, Multiplicative a) => [Int] -> Array a
ident ds = tabulate ds (\i -> bool zero one (allEqual i))
  where
    -- compare each component with its successor; lists of length < 2 pass
    allEqual (x : rest@(y : _)) = x == y && allEqual rest
    allEqual _ = True

-- | The elements whose index components are all equal, as a vector whose
-- length is the smallest dimension of the input.
--
-- >>> diag (ident [3,2])
-- [1, 1]
diag ::
  Array a ->
  Array a
diag a = tabulate [NumHask.Array.Shape.minimum (shape a)] pick
  where
    r = rank (shape a)
    pick (k : _) = index a (replicate r k)
    pick [] = throw (NumHaskException "Rank Underflow")

-- | An array of the given shape with every element set to the same value.
--
-- >>> singleton [3,2] one
-- [[1, 1],
--  [1, 1],
--  [1, 1]]
singleton :: [Int] -> a -> Array a
singleton ds a = tabulate ds (\_ -> a)

-- | Fix the dimensions @ds@ at the positions @i@; the remaining dimensions
-- make up the resulting sub-array.
--
-- >>> let s = selects [0,1] [1,1] a
-- >>> s
-- [17, 18, 19, 20]
selects ::
  [Int] ->
  [Int] ->
  Array a ->
  Array a
selects ds i a =
  tabulate (dropIndexes (shape a) ds) (\s -> index a (addIndexes s ds i))

-- | Like 'selects', but @ds@ names the dimensions to keep; every other
-- dimension is fixed at the positions @i@.
--
-- >>> let s = selectsExcept [2] [1,1] a
-- >>> s
-- [17, 18, 19, 20]
selectsExcept ::
  [Int] ->
  [Int] ->
  Array a ->
  Array a
selectsExcept ds i a = selects fixed i a
  where
    fixed = exclude (rank (shape a)) ds

-- | Reduce every sub-array spanned by the remaining dimensions with @f@,
-- keeping the dimensions @ds@ as the shape of the result.
--
-- >>> folds sum [1] a
-- [68, 100, 132]
folds ::
  (Array a -> b) ->
  [Int] ->
  Array a ->
  Array b
folds f ds a = fmap f (extracts ds a)

-- | Split an array into an array of arrays: the dimensions @ds@ form the
-- outer array and the remaining dimensions form each inner array.
--
-- >>> let e = extracts [1,2] a
-- >>> shape <$> extracts [0] a
-- [[3,4], [3,4]]
extracts ::
  [Int] ->
  Array a ->
  Array (Array a)
extracts ds a = tabulate outer (\s -> selects ds s a)
  where
    outer = takeIndexes (shape a) ds

-- | Split an array into an array of arrays, keeping the dimensions @ds@ in
-- the inner arrays and moving every other dimension to the outer array.
--
-- @extractsExcept ds a@ is 'extracts' applied to the dimensions of @a@ that
-- are not listed in @ds@.
--
-- >>> shape <$> extractsExcept [1,2] a
-- [[3,4], [3,4]]
extractsExcept ::
  [Int] ->
  Array a ->
  Array (Array a)
extractsExcept ds a = extracts (exclude (rank (shape a)) ds) a

-- | Inverse of 'extracts': merge an array of arrays back into one array, with
-- the outer dimensions placed at positions @ds@.
--
-- The inner shape is read from the element at the outer origin, so every
-- inner array is assumed to share that shape (not checked), and the outer
-- array is assumed to be non-empty.
--
-- >>> let e = extracts [1,0] a
-- >>> let j = joins [1,0] e
-- >>> a == j
-- True
joins ::
  [Int] ->
  Array (Array a) ->
  Array a
joins ds a = tabulate (addIndexes innerShape ds outerShape) pick
  where
    outerShape = shape a
    innerShape = shape (index a (replicate (rank outerShape) 0))
    pick s = index (index a (takeIndexes s ds)) (dropIndexes s ds)

-- | Apply an array-to-array function to each sub-array obtained by
-- extracting the dimensions @ds@, then join the results back together.
--
-- >>> shape $ maps (transpose) [1] a
-- [4,3,2]
maps ::
  (Array a -> Array b) ->
  [Int] ->
  Array a ->
  Array b
maps f ds = joins ds . fmap f . extracts ds

-- | Join two arrays end to end along dimension @d@.
--
-- Positions along @d@ below the extent of @a0@ come from @a0@; the rest come
-- from @a1@, shifted back by that extent.
--
-- >>> shape $ concatenate 1 a a
-- [2,6,4]
concatenate ::
  Int ->
  Array a ->
  Array a ->
  Array a
concatenate d a0 a1 = tabulate (concatenate' d (shape a0) (shape a1)) pick
  where
    split = shape a0 !! d
    pick s
      | s !! d >= split = index a1 (addIndex (dropIndex s d) d (s !! d - split))
      | otherwise = index a0 s

-- | Insert the slice @b@ into @a@ at position @i@ along dimension @d@,
-- making @a@ one longer in that dimension.
--
-- >>> insert 2 0 a (fromFlatList [2,3] [100..105])
-- [[[100, 1, 2, 3, 4],
--   [101, 5, 6, 7, 8],
--   [102, 9, 10, 11, 12]],
--  [[103, 13, 14, 15, 16],
--   [104, 17, 18, 19, 20],
--   [105, 21, 22, 23, 24]]]
insert ::
  Int ->
  Int ->
  Array a ->
  Array a ->
  Array a
insert d i a b = tabulate (incAt d (shape a)) pick
  where
    pick s
      | k < i = index a s
      | k == i = index b (dropIndex s d)
      | otherwise = index a (decAt d s)
      where
        k = s !! d

-- | Add the slice @b@ after the last position of @a@ along dimension @d@.
--
-- >>> append 2 a (fromFlatList [2,3] [100..105])
-- [[[1, 2, 3, 4, 100],
--   [5, 6, 7, 8, 101],
--   [9, 10, 11, 12, 102]],
--  [[13, 14, 15, 16, 103],
--   [17, 18, 19, 20, 104],
--   [21, 22, 23, 24, 105]]]
append ::
  Int ->
  Array a ->
  Array a ->
  Array a
append d a = insert d (dimension (shape a) d) a

-- | Permute the dimensions: dimension @k@ of the result is dimension
-- @ds !! k@ of the input.
--
-- >>> let r = reorder [2,0,1] a
-- >>> r
-- [[[1, 5, 9],
--   [13, 17, 21]],
--  [[2, 6, 10],
--   [14, 18, 22]],
--  [[3, 7, 11],
--   [15, 19, 23]],
--  [[4, 8, 12],
--   [16, 20, 24]]]
reorder ::
  [Int] ->
  Array a ->
  Array a
reorder ds a = tabulate (reorder' (shape a) ds) (\s -> index a (addIndexes [] ds s))

-- | Combine every element of one array with every element of another.
--
-- The result has shape @shape a ++ shape b@, and the element at index
-- @i ++ j@ is @f (index a i) (index b j)@. With multiplication on tensors this
-- is the tensor (outer) product.
--
-- https://en.wikipedia.org/wiki/Tensor_product
--
-- The concept of a tensor product is a dense crossroad, and a complete treatment is elsewhere.  To quote:
-- ... the tensor product can be extended to other categories of mathematical objects in addition to vector spaces, such as to matrices, tensors, algebras, topological vector spaces, and modules. In each such case the tensor product is characterized by a similar universal property: it is the freest bilinear operation. The general concept of a "tensor product" is captured by monoidal categories; that is, the class of all things that have a tensor product is a monoidal category.
--
-- >>> expand (*) v v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
expand ::
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array c
expand f a b = tabulate (shape a ++ shape b) combine
  where
    r = rank (shape a)
    -- the first r components index a, the rest index b
    combine i = f (index a (take r i)) (index b (drop r i))

-- | Apply every function in one array to every value in another, in the
-- spirit of the applicative @(<*>)@.
--
-- The result has shape @shape f ++ shape a@.
--
-- > expand f a b == apply (fmap f a) b
--
-- >>> apply ((*) <$> v) v
-- [[1, 2, 3],
--  [2, 4, 6],
--  [3, 6, 9]]
--
-- >>> let b = fromFlatList [2,3] [1..6] :: Array Int
-- >>> contract sum [1,2] (apply (fmap (*) b) (transpose b))
-- [[14, 32],
--  [32, 77]]
apply ::
  Array (a -> b) ->
  Array a ->
  Array b
apply f a = tabulate (shape f ++ shape a) at
  where
    r = rank (shape f)
    -- the first r components pick the function, the rest pick its argument
    at i = index f (take r i) (index a (drop r i))

-- | Contract the dimensions @xs@: for each position in the other dimensions,
-- take the diagonal of the sub-array spanned by @xs@ and reduce it with @f@.
--
-- This generalises tensor contraction to any number of contracted dimensions
-- and to reductions other than summing products.
--
-- >>> let b = fromFlatList [2,3] [1..6] :: Array Int
-- >>> contract sum [1,2] (expand (*) b (transpose b))
-- [[14, 32],
--  [32, 77]]
contract ::
  (Array a -> b) ->
  [Int] ->
  Array a ->
  Array b
contract f xs a = fmap (f . diag) (extractsExcept xs a)

-- | Generalised dot product.
--
-- Take the 'expand' of the two arrays with @g@, then 'contract' the last
-- dimension of @a@ against the first dimension of @b@ using @f@.
--
-- matrix multiplication
--
-- >>> let b = fromFlatList [2,3] [1..6] :: Array Int
-- >>> dot sum (*) b (transpose b)
-- [[14, 32],
--  [32, 77]]
--
-- inner product
--
-- >>> let v = fromFlatList [3] [1..3] :: Array Int
-- >>> dot sum (*) v v
-- 14
--
-- matrix-vector multiplication
-- Note that an `Array Int` with shape [3] is neither a row vector nor column vector. `dot` is not turning the vector into a matrix and then using matrix multiplication.
--
-- >>> dot sum (*) v b
-- [9, 12, 15]
--
-- >>> dot sum (*) b v
-- [14, 32]
dot ::
  (Array c -> d) ->
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array d
dot f g a b = contract f [r - 1, r] (expand g a b)
  where
    -- in the expanded array, a's last dimension sits at r - 1 and b's first at r
    r = rank (shape a)

-- | Array multiplication: 'dot' with @sum@ as the reduction and @(*)@ as the
-- combining function.
--
-- matrix multiplication
--
-- >>> let b = fromFlatList [2,3] [1..6] :: Array Int
-- >>> mult b (transpose b)
-- [[14, 32],
--  [32, 77]]
--
-- inner product
--
-- >>> let v = fromFlatList [3] [1..3] :: Array Int
-- >>> mult v v
-- 14
--
-- matrix-vector multiplication
--
-- >>> mult v b
-- [9, 12, 15]
--
-- >>> mult b v
-- [14, 32]
mult ::
  ( Additive a,
    Multiplicative a
  ) =>
  Array a ->
  Array a ->
  Array a
mult a b = dot sum (*) a b

-- | Keep only chosen positions in every dimension: @pss !! k@ lists the
-- positions kept along dimension @k@.
--
-- >>> let s = slice [[0,1],[0,2],[1,2]] a
-- >>> s
-- [[[2, 3],
--   [10, 11]],
--  [[14, 15],
--   [22, 23]]]
slice ::
  [[Int]] ->
  Array a ->
  Array a
slice pss a = tabulate (ranks pss) (index a . zipWith (!!) pss)

-- | Drop every dimension of extent one. The flat store is reused unchanged.
--
-- >>> let a' = fromFlatList [2,1,3,4,1] [1..24] :: Array Int
-- >>> shape $ squeeze a'
-- [2,3,4]
squeeze ::
  Array a ->
  Array a
squeeze a = a {shape = squeeze' (shape a)}

-- | Unwrap a scalar by reading the element at the empty index.
--
-- Unwrapping scalars is probably a performance bottleneck.
--
-- >>> let s = fromFlatList [] [3] :: Array Int
-- >>> fromScalar s
-- 3
fromScalar :: Array a -> a
fromScalar a = index a []

-- | Wrap a single value as a scalar (an array with the empty shape).
--
-- >>> :t toScalar 2
-- toScalar 2 :: FromInteger a => Array a
toScalar :: a -> Array a
toScalar x = fromFlatList [] [x]

-- | Row @i@ of a matrix, taken as a contiguous slice of the flat store.
--
-- >>> row 1 m
-- [4, 5, 6, 7]
row :: Int -> Array a -> Array a
row i (Array s a) = Array [cols] (V.slice (i * cols) cols a)
  where
    (_ : cols : _) = s

-- | Column @i@ of a matrix, gathered by striding through the flat store.
--
-- Throws a 'NumHaskException' when @i@ is not a valid column index. The store
-- is read with 'V.unsafeIndex', so the check is needed to keep the reads in
-- bounds. ('row' gets the same protection from 'V.slice'.)
--
-- >>> col 1 m
-- [1, 5, 9]
col :: Int -> Array a -> Array a
col i (Array s a)
  | i < 0 = outOfRange
  | i >= cols = outOfRange
  | otherwise =
      Array [rows] $ V.generate rows (\r -> V.unsafeIndex a (i + r * cols))
  where
    (rows : cols : _) = s
    outOfRange = throw (NumHaskException "col: column index out of range")

-- | Matrix multiplication.
--
-- This is @dot sum (*)@ specialised to matrices, working directly on the flat
-- stores. Throws a 'NumHaskException' when either argument has fewer than two
-- dimensions, or when the column count of the first argument differs from the
-- row count of the second.
--
-- >>> let a = fromFlatList [2,2] [1, 2, 3, 4] :: Array Int
-- >>> let b = fromFlatList [2,2] [5, 6, 7, 8] :: Array Int
-- >>> a
-- [[1, 2],
--  [3, 4]]
--
-- >>> b
-- [[5, 6],
--  [7, 8]]
--
-- >>> mmult a b
-- [[19, 22],
--  [43, 50]]
mmult ::
  (Ring a) =>
  Array a ->
  Array a ->
  Array a
mmult (Array sx x) (Array sy y) =
  case (sx, sy) of
    (m : k : _, ky : n : _)
      | k == ky -> tabulate [m, n] (cell k n)
      | otherwise -> throw (NumHaskException "Inner dimensions do not match")
    _ -> throw (NumHaskException "Needs two dimensions")
  where
    -- dot product of row i of x (contiguous) with column j of y (stride n)
    cell k n (i : j : _) =
      sum $ V.zipWith (*) (V.slice (i * k) k x) (V.generate k (\x' -> y V.! (j + x' * n)))
    -- unreachable: tabulate [m, n] always supplies two-component indices
    cell _ _ _ = throw (NumHaskException "Needs two dimensions")
{-# INLINE mmult #-}