module Data.Tensor.Vector.Internal where
import Data.Cardinal
import Data.TypeList.MultiIndex hiding (drop, take)
import Data.TypeList.MultiIndex.Internal
import Data.Ordinal
import Data.Tensor hiding (Tensor)
import qualified Data.Tensor as T
import qualified Data.Vector as V
data Tensor i e = Tensor
{ form :: [Int]
, content :: V.Vector e
} deriving Eq
type Vector n = Tensor (n :|: Nil)
type Matrix m n = Tensor (m :|: (n :|: Nil))
type ColumnVector n = Matrix n One
vector2ColumnVector :: Vector n e -> ColumnVector n e
vector2ColumnVector (Tensor ds x) = (Tensor (ds ++ [1]) x)
columnVector2Vector :: ColumnVector n e -> Vector n e
columnVector2Vector (Tensor ds x) = (Tensor (init ds) x)
type RowVector n = Matrix One n
vector2RowVector :: Vector n e -> RowVector n e
vector2RowVector (Tensor ds x) = (Tensor (1 : ds) x)
rowVector2Vector :: RowVector n e -> Vector n e
rowVector2Vector (Tensor ds x) = (Tensor (tail ds) x)
instance Show e => Show (Tensor i e) where
showsPrec _ (Tensor [] v) = shows $ v V.! 0
showsPrec _ (Tensor ds v) = let sd = Prelude.reverse ds in
let l = V.length v in
let r = Prelude.length ds in
showsT sd l (Prelude.replicate r 1) 1 .
(shows $ v V.! (l1)) .
(Prelude.replicate r ']' ++)
where showsT sd l ys n = let (zs,k) = match sd ys in
if n < l
then showsT sd l zs (n+1) .
(shows $ v V.! (ln1)) .
(Prelude.replicate k ']' ++) .
(',':) . (Prelude.replicate k '[' ++)
else (Prelude.replicate k '[' ++)
match is js = match' is js [] 0
where match' [] _ zs n = (zs,n)
match' _ [] zs n = (zs,n)
match' (x:xs) (y:ys) zs n | x == y =
match' xs ys (zs ++ [1]) (n+1)
| otherwise =
(zs ++ ((y+1):ys),n)
instance Functor (Tensor i) where
fmap f (Tensor is v) = Tensor is (fmap f v)
instance Zip (Tensor i) where
zipWith f (Tensor d x) (Tensor _ y) = Tensor d $ V.zipWith f x y
class FromVector t where
fromVector :: V.Vector e -> t e
instance (Bounded i, Cardinality i, MultiIndex i) =>
FromVector (Tensor i) where
fromVector x = toTensor maxBound x
where
toTensor :: (Cardinality i, MultiIndex i) =>
i -> V.Vector e -> Tensor i e
toTensor i v | V.length v < l = error ("fromVector: length of vector \
\must be at least " ++ show l)
| otherwise = Tensor (dimensions i) (V.take l v)
where l = card i
instance (Bounded i, Cardinality i, MultiIndex i) =>
FromList (Tensor i) where
fromList = fromVector . V.fromList
instance (Bounded i, MultiIndex i) => T.Tensor (Tensor i e) where
type Index (Tensor i e) = i
type Elem (Tensor i e) = e
dims _ = maxBound
(Tensor _ x) ! j = x V.! multiIndex2Linear j
generate f = t
where t = Tensor d $
V.generate (product d) (f . toMultiIndex . (unlinearize d))
d = fromMultiIndex $ dims $ asTypeOf undefined t
replicate e = t
where t = Tensor d $
V.replicate (product d) e
d = fromMultiIndex $ dims $ asTypeOf undefined t
instance (Cardinal n, MultiIndex i, MultiIndex j, MultiIndexConcat n i j)
=> DirectSum n (Tensor i e) (Tensor j e) where
type SumSpace n (Tensor i e) (Tensor j e) = (Tensor (Concat n i j) e)
directSum n (Tensor d x) (Tensor d' y) = Tensor ((take i d) ++ e'') (V.generate l g)
where l = product ((take i d) ++ e'')
e = drop i d
e' = drop i d'
e'' = ((d !! i) + (d' !! i)) : (drop (i+1) d)
m = product e
m' = product e'
m'' = product e''
g k = if rem k m'' < m
then x V.! ((quot k m'')*m + (rem k m''))
else y V.! ((quot k m'')*m' + (rem k m'') m)
i = fromCardinal n
instance (Ordinal i, Ordinal j) => Transpose (Tensor (i :|: (j :|: Nil)) e)
where
type TransposeSpace (Tensor (i :|: (j :|: Nil)) e) = (Tensor (j :|: (i :|: Nil)) e)
transpose (Tensor ds x) = Tensor [d2,d1] (V.generate (d1*d2) t)
where t n = let (q,r) = quotRem n d1
in x V.! (r * d2 + q)
d1 = head ds
d2 = head $ tail ds
unsafeTensorGet :: [Int] -> Tensor i e -> e
unsafeTensorGet is (Tensor ds x) = x V.! linearize ds is
unsafeTensorGen :: [Int] -> ([Int] -> e) -> Tensor i e
unsafeTensorGen ds f =
Tensor ds $ V.generate (product ds) (f . (unlinearize ds))
unsafeVectorGet :: Int -> Vector i e -> e
unsafeVectorGet i (Tensor _ x) = x V.! i
unsafeVectorGen :: Int -> (Int -> e) -> Vector i e
unsafeVectorGen d f = Tensor [d] $ V.generate d f
unsafeMatrixGet :: Int -> Int -> Matrix i j e -> e
unsafeMatrixGet i j (Tensor ds x) = x V.! linearize ds [i,j]
unsafeMatrixGen :: Int -> Int -> (Int -> Int -> e) -> Matrix i j e
unsafeMatrixGen d e f = Tensor [d,e] $ V.generate (d*e)
(\n -> let [i,j] = unlinearize [d,e] n in
f i j
)
unsafeMatrixGetRow :: Int -> Matrix i j e -> Vector j e
unsafeMatrixGetRow i (Tensor ds x) = Tensor (tail ds) $
V.slice ((i1)*d) d x
where d = last ds
unsafeMatrixRowSwitch :: Int -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixRowSwitch i1 i2 (Tensor ds x) = Tensor ds $ V.generate (d*e) rs
where d = head ds
e = last ds
rs n | quot n e == i1 1 = x V.! (n + off)
| quot n e == i2 1 = x V.! (n off)
| otherwise = x V.! n
off = e * (i2 i1)
unsafeMatrixColSwitch :: Int -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixColSwitch j1 j2 (Tensor ds x) = Tensor ds $ V.generate (d*e) cs
where d = head ds
e = last ds
cs n | rem n e == j1 1 = x V.! (n + off)
| rem n e == j2 1 = x V.! (n off)
| otherwise = x V.! n
off = j2 j1
unsafeMatrixRowMult :: (Num e) => Int -> e -> Matrix i j e -> Matrix i j e
unsafeMatrixRowMult i a (Tensor ds x) = Tensor ds $ V.generate (d*e) rm
where d = head ds
e = last ds
rm n = if quot n e == i 1
then (x V.! n) * a
else x V.! n
unsafeMatrixColMult :: (Num e) => Int -> e -> Matrix i j e -> Matrix i j e
unsafeMatrixColMult j a (Tensor ds x) = Tensor ds $ V.generate (d*e) cm
where d = head ds
e = last ds
cm n = if rem n e == j 1
then (x V.! n) * a
else x V.! n
unsafeMatrixRowDiv :: (Fractional e) => Int -> e -> Matrix i j e -> Matrix i j e
unsafeMatrixRowDiv i a (Tensor ds x) = Tensor ds $ V.generate (d*e) rd
where d = head ds
e = last ds
rd n = if quot n e == i 1
then (x V.! n) / a
else x V.! n
unsafeMatrixColDiv :: (Fractional e) => Int -> e -> Matrix i j e -> Matrix i j e
unsafeMatrixColDiv j a (Tensor ds x) = Tensor ds $ V.generate (d*e) cd
where d = head ds
e = last ds
cd n = if rem n e == j 1
then (x V.! n) / a
else x V.! n
unsafeMatrixRowAdd :: (Num e) => Int -> e -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixRowAdd i1 a i2 (Tensor ds x) = Tensor ds $ V.generate (d*e) ra
where d = head ds
e = last ds
ra n | quot n e == i1 1 = x V.! n + (x V.! (n + off))*a
| otherwise = x V.! n
off = e * (i2 i1)
unsafeMatrixColAdd :: (Num e) => Int -> e -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixColAdd j1 a j2 (Tensor ds x) = Tensor ds $ V.generate (d*e) ca
where d = head ds
e = last ds
ca n | rem n e == j1 1 = x V.! n + (x V.! (n + off))*a
| otherwise = x V.! n
off = j2 j1
unsafeMatrixRowSub :: (Num e) => Int -> e -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixRowSub i1 a i2 (Tensor ds x) = Tensor ds $ V.generate (d*e) rs
where d = head ds
e = last ds
rs n | quot n e == i1 1 = x V.! n (x V.! (n + off))*a
| otherwise = x V.! n
off = e * (i2 i1)
unsafeMatrixColSub :: (Num e) => Int -> e -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixColSub j1 a j2 (Tensor ds x) = Tensor ds $ V.generate (d*e) cs
where d = head ds
e = last ds
cs n | rem n e == j1 1 = x V.! n (x V.! (n + off))*a
| otherwise = x V.! n
off = j2 j1