```{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.Arithmetic.LinearAlgebra where

import qualified Data.Array.Accelerate.Utility.Loop as Loop
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate
(Acc, Array, Exp, Any(Any), All(All), Z(Z), (:.)((:.)))

{- |
An accelerated array of shape @ix@.
The functions in this module treat @ix@ as an outer index
that is carried along unchanged.
-}
type Scalar ix a = Acc (Array ix a)
{- | Like 'Scalar', with one additional trailing @Int@ dimension. -}
type Vector ix a = Acc (Array (ix :. Int) a)
{- |
Like 'Scalar', with two additional trailing @Int@ dimensions.
The functions in this module name them rows and columns, in that order.
-}
type Matrix ix a = Acc (Array (ix :. Int :. Int) a)

{- |
Exchange the two trailing dimensions (rows and columns) of the array.
The leading index @ix@ is passed through untouched.
-}
transpose ::
  (A.Shape ix, A.Slice ix, A.Elt a) =>
  Matrix ix a -> Matrix ix a
transpose m =
  A.backpermute
    -- the result shape is the input shape with rows and columns swapped
    (A.lift (swapIndex (matrixShape m)))
    -- every result index reads from the swapped source index
    (withMatrixIndex swapIndex)
    m

{- |
Swap the last two components of an unlifted index
and keep the leading part as it is.
-}
swapIndex ::
  Exp ix :. Exp Int :. Exp Int ->
  Exp ix :. Exp Int :. Exp Int
swapIndex idx =
  case idx of
    rest :. i :. j -> rest :. j :. i

{- | Extent of the last dimension of a 'Vector'. -}
numElems :: (A.Shape ix, A.Slice ix, A.Elt a) => Vector ix a -> Exp Int
numElems m = len
  where
    (_ :. len) = vectorShape m

{- | Extent of the second-to-last dimension of a 'Matrix'. -}
numRows :: (A.Shape ix, A.Slice ix, A.Elt a) => Matrix ix a -> Exp Int
numRows m = rowCount
  where
    (_ :. rowCount :. _) = matrixShape m

{- | Extent of the last dimension of a 'Matrix'. -}
numCols :: (A.Shape ix, A.Slice ix, A.Elt a) => Matrix ix a -> Exp Int
numCols m = colCount
  where
    (_ :. _ :. colCount) = matrixShape m

{- |
Shape of a 'Vector', unlifted into the leading index
and the extent of the last dimension.
-}
vectorShape ::
  (A.Shape ix, A.Slice ix, A.Elt a) =>
  Vector ix a -> Exp ix :. Exp Int
vectorShape = A.unlift . A.shape

{- |
Shape of a 'Matrix', unlifted into the leading index,
the number of rows and the number of columns.
-}
matrixShape ::
  (A.Shape ix, A.Slice ix, A.Elt a) =>
  Matrix ix a -> Exp ix :. Exp Int :. Exp Int
matrixShape = A.unlift . A.shape

{- |
Turn a function on an unlifted vector index into a function on the
lifted 'Exp' index: the argument is unlifted, passed to @f@,
and the result of @f@ is lifted again.
-}
withVectorIndex ::
  (A.Shape ix, A.Slice ix, A.Lift Exp a) =>
  (Exp ix :. Exp Int -> a) ->
  (Exp (ix :. Int) -> Exp (A.Plain a))
withVectorIndex f i = A.lift (f (A.unlift i))

{- |
Turn a function on an unlifted matrix index into a function on the
lifted 'Exp' index: the argument is unlifted, passed to @f@,
and the result of @f@ is lifted again.
-}
withMatrixIndex ::
  (A.Shape ix, A.Slice ix, A.Lift Exp a) =>
  (Exp ix :. Exp Int :. Exp Int -> a) ->
  (Exp (ix :. Int :. Int) -> Exp (A.Plain a))
withMatrixIndex f i = A.lift (f (A.unlift i))

{- |
Outer product of two vectors:
the element at row @i@, column @j@ is @x[i] * y[j]@.
-}
outer ::
  (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
  Vector ix a -> Vector ix a -> Matrix ix a
outer x y =
  A.zipWith (*) xAcrossColumns yAcrossRows
  where
    -- x gets a new last dimension with as many entries as y has elements
    xAcrossColumns = A.replicate (A.lift (Any :. All :. numElems y)) x
    -- y gets a new row dimension with as many entries as x has elements
    yAcrossRows = A.replicate (A.lift (Any :. numElems x :. All)) y

{- |
Matrix-vector product.
The vector is repeated once per matrix row, multiplied element-wise
with the matrix, and the products are summed along the last dimension.
The sum uses 'A.fold1', i.e. no neutral element is supplied.
-}
multiplyMatrixVector ::
  (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
  Matrix ix a ->
  Vector ix a ->
  Vector ix a
multiplyMatrixVector m v =
  A.fold1 (+) (A.zipWith (*) m vPerRow)
  where
    vPerRow = A.replicate (A.lift (Any :. numRows m :. All)) v

{- |
Matrix-matrix product @x * y@.
Both operands are replicated into a common array indexed by
(row of @x@, shared dimension, column of @y@) and multiplied element-wise.
'transpose' then moves the shared dimension to the last position,
where 'A.fold1' sums it away.
-}
multiplyMatrixMatrix ::
  (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
  Matrix ix a ->
  Matrix ix a ->
  Matrix ix a
multiplyMatrixMatrix x y =
  A.fold1 (+) (transpose products)
  where
    products =
      A.zipWith (*)
        (A.replicate (A.lift (Any :. All :. All :. numCols y)) x)
        (A.replicate (A.lift (Any :. numRows x :. All :. All)) y)

{- |
One step of the Newton iteration for a matrix inverse:
maps the current approximation @x@ to @2*x - x*(a*x)@.
-}
newtonInverseStep ::
  (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
  Matrix ix a ->
  Matrix ix a ->
  Matrix ix a
newtonInverseStep a x =
  A.zipWith (-) doubled correction
  where
    doubled = A.map (2 *) x
    correction = multiplyMatrixMatrix x (multiplyMatrixMatrix a x)

{- |
Array of the given shape that holds 1 wherever the row index equals
the column index and 0 everywhere else.
-}
identity ::
  (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
  Exp (ix :. Int :. Int) -> Matrix ix a
identity sh =
  A.generate sh (withMatrixIndex diagonal)
  where
    -- the comparison yields 1 or 0 as Int, which is then converted to @a@
    diagonal (_ :. r :. c) = A.fromIntegral (A.boolToInt (r A.==* c))

{- |
Refine an approximate inverse @seed@ of the matrix @a@
by running 'newtonInverseStep' through 'Loop.nest' with count @n@.
Note the argument order: the seed comes before the matrix.

NOTE(review): 'Loop.nest' is defined elsewhere; presumably it applies
the step @n@ times - confirm against its documentation.
-}
newtonInverse ::
  (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
  Exp Int ->
  Matrix ix a ->
  Matrix ix a ->
  Matrix ix a
newtonInverse n seed a = Loop.nest n step seed
  where
    step = newtonInverseStep a

{- |
Multiply every row of the matrix by the corresponding element
of the vector: element @(r, c)@ of the result is @s[r] * x[r, c]@.
-}
scaleRows ::
  (A.Slice ix, A.Shape ix, A.Elt a, A.IsNum a) =>
  Vector ix a -> Matrix ix a -> Matrix ix a
scaleRows = zipScalarVectorWith (*)

{- |
Combine each scalar with all elements of the vector at the same
outer index. The scalar array is replicated along a new last dimension
of the same extent as the vector's before zipping with @f@.
-}
zipScalarVectorWith ::
  (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
  (Exp a -> Exp b -> Exp c) ->
  Scalar ix a -> Vector ix b -> Vector ix c
zipScalarVectorWith f x ys =
  A.zipWith f (A.replicate (A.lift (Any :. numElems ys)) x) ys

{- |
Combine each scalar with all elements of the matrix at the same
outer index. The scalar array is replicated along two new trailing
dimensions matching the matrix's rows and columns before zipping with @f@.
-}
zipScalarMatrixWith ::
  (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
  (Exp a -> Exp b -> Exp c) ->
  Scalar ix a -> Matrix ix b -> Matrix ix c
zipScalarMatrixWith f x ys =
  A.zipWith f spread ys
  where
    (_ :. rows :. cols) = matrixShape ys
    spread = A.replicate (A.lift (Any :. rows :. cols)) x

{- |
Reinterpret a vector as a matrix with a single column
by appending a dimension of extent 1 to its shape.
-}
columnFromVector ::
  (A.Shape ix, A.Slice ix, A.Elt a) =>
  Vector ix a -> Matrix ix a
columnFromVector a = A.reshape columnShape a
  where
    columnShape = Exp.indexCons (A.shape a) 1

{- |
Reinterpret a matrix as a vector by dropping the last dimension
of its shape.
The input must be a matrix with exactly one column.
-}
vectorFromColumn ::
  (A.Shape ix, A.Slice ix, A.Elt a) =>
  Matrix ix a -> Vector ix a
vectorFromColumn a = A.reshape (A.indexTail (A.shape a)) a

{- |
Merge the row and column dimensions of a matrix into one dimension
of extent @rows*cols@.
'flattenMatrix' is the default and currently forwards to
'flattenMatrixBackPermute'; 'flattenMatrixReshape' does the same job
with 'A.reshape'.
-}
flattenMatrix, flattenMatrixReshape, flattenMatrixBackPermute ::
  (A.Slice ix, A.Shape ix, A.Elt a) =>
  Matrix ix a -> Vector ix a
flattenMatrix m = flattenMatrixBackPermute m

-- Variant that only changes the shape of the array.
flattenMatrixReshape m =
  A.reshape (A.lift (ix :. rows * cols)) m
  where
    (ix :. rows :. cols) = matrixShape m

{- |
Quotient and remainder as a pair, computed with separate calls
to 'div' and 'mod'.

NOTE(review): presumably this avoids 'divMod' because the 'Integral'
instance used for 'Exp' may not implement it - confirm before simplifying.
-}
accDivMod :: Integral a => a -> a -> (a, a)
accDivMod x y = (x `div` y, x `mod` y)

-- Variant of flattenMatrix built on A.backpermute:
-- position n of the flat vector reads the matrix element
-- at row (n div cols) and column (n mod cols).
flattenMatrixBackPermute m =
  A.backpermute (A.lift (ix :. rows * cols)) (withVectorIndex source) m
  where
    (ix :. rows :. cols) = matrixShape m
    source (vix :. n) =
      let (r, c) = accDivMod n cols
      in vix :. r :. c

{- |
Split the last dimension of a vector into rows and columns,
given the number of columns.
The number of rows is the vector length divided by @cols@
using integer division.
'restoreMatrix' is the default and currently forwards to
'restoreMatrixBackPermute'; 'restoreMatrixReshape' does the same job
with 'A.reshape'.
-}
restoreMatrix, restoreMatrixReshape, restoreMatrixBackPermute ::
  (A.Slice ix, A.Shape ix, A.Elt a) =>
  Exp Int -> Vector ix a -> Matrix ix a
restoreMatrix cols v = restoreMatrixBackPermute cols v

-- Variant that only changes the shape of the array.
restoreMatrixReshape cols v =
  A.reshape (A.lift (ix :. (n `div` cols) :. cols)) v
  where
    (ix :. n) = vectorShape v

-- Variant of restoreMatrix built on A.backpermute:
-- the matrix element at row k and column j reads position k*cols+j
-- of the flat vector.
restoreMatrixBackPermute cols v =
  A.backpermute
    (A.lift (ix :. rowCount :. cols))
    (withMatrixIndex flatIndex)
    v
  where
    (ix :. n) = vectorShape v
    rowCount = div n cols
    flatIndex (vix :. k :. j) = vix :. k * cols + j

{- |
Repeat a single vector over an outer shape:
every position of @shape@ holds a copy of @y@.

The result has shape @shape :. numElems y@, and the element at
@(_ :. n)@ is read from @y@ at @Z :. n@, in the same way as
'extrudeMatrix' does for matrices.
-}
extrudeVector ::
  (A.Shape ix, A.Slice ix, A.Elt a) =>
  Exp ix -> Vector Z a -> Vector ix a
extrudeVector shape y =
  -- An alternative via A.replicate was sketched here but left disabled:
  --   A.replicate (A.lift (shape :. All)) y
  A.backpermute
    (A.lift (shape :. numElems y))
    -- A.backpermute needs the mapping from result index to source index;
    -- it drops the outer index and keeps the position within the vector.
    (withVectorIndex (\(_ :. n) -> Z :. n))
    y

{- |
Repeat a single matrix over an outer shape:
every position of @shape@ holds a copy of @y@.
The result has shape @shape :. numRows y :. numCols y@.
-}
extrudeMatrix ::
  (A.Shape ix, A.Slice ix, A.Elt a) =>
  Exp ix -> Matrix Z a -> Matrix ix a
extrudeMatrix shape y =
  A.backpermute target (withMatrixIndex dropOuter) y
  where
    target = A.lift (shape :. numRows y :. numCols y)
    -- the outer index is ignored; row and column select the source element
    dropOuter (_ :. r :. c) = Z :. r :. c

{- |
Zip a single vector @x@ with every vector contained in @y@.
@x@ is first extruded with 'extrudeVector' over the outer shape of @y@,
that is, the shape of @y@ without its last dimension.
-}
zipExtrudedVectorWith ::
  (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
  (Exp a -> Exp b -> Exp c) ->
  Vector Z a ->
  Vector ix b ->
  Vector ix c
zipExtrudedVectorWith f x y =
  A.zipWith f (extrudeVector outerShape x) y
  where
    outerShape = A.indexTail (A.shape y)

{- |
Zip a single matrix @x@ with every matrix contained in @y@.
@x@ is first extruded with 'extrudeMatrix' over the outer shape of @y@,
that is, the shape of @y@ without its last two dimensions.
-}
zipExtrudedMatrixWith ::
  (A.Slice ix, A.Shape ix, A.Elt a, A.Elt b, A.Elt c) =>
  (Exp a -> Exp b -> Exp c) ->
  Matrix Z a ->
  Matrix ix b ->
  Matrix ix c
zipExtrudedMatrixWith f x y =
  A.zipWith f (extrudeMatrix outerShape x) y
  where
    outerShape = (A.indexTail . A.indexTail) (A.shape y)

{- |
Look up elements of a one-dimensional vector by position.
Every @Int@ in @indices@ is wrapped into a one-dimensional index
with 'A.index1' and the resulting index array is handed to
'Arrange.gather' together with the source vector.

NOTE(review): 'Arrange.gather' is defined elsewhere; presumably the result
holds, at each position, the source element at the given index - confirm.
-}
gatherFromVector ::
  (A.Shape ix, A.Elt a) =>
  Scalar ix Int -> Vector Z a -> Scalar ix a
gatherFromVector indices source =
  Arrange.gather (A.map A.index1 indices) source
```