{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Internal representation of tensors backed by a flat boxed vector,
-- together with unchecked accessors and elementary matrix operations.
module Data.Tensor.Vector.Internal where

import           Data.Cardinal
import           Data.TypeList.MultiIndex hiding (drop, take)
import           Data.TypeList.MultiIndex.Internal
import           Data.Ordinal
import           Data.Tensor hiding (Tensor)
import qualified Data.Tensor as T
import qualified Data.Vector as V


-- | A tensor stored as a list of dimension sizes plus a flat vector of
-- elements. The index type @i@ is a phantom parameter: it does not occur
-- in either field.
data Tensor i e = Tensor
    { form    :: [Int]        -- ^ Size of every dimension.
    , content :: V.Vector e   -- ^ The elements, stored flat.
    } deriving Eq


-- | A tensor indexed by a single-component multi-index.
type Vector n = Tensor (n :|: Nil)


-- | A tensor indexed by a two-component multi-index.
type Matrix m n = Tensor (m :|: (n :|: Nil))


-- | A matrix whose second index type is @One@.
type ColumnVector n = Matrix n One


-- | Reinterpret a vector as a column vector by appending a dimension of
-- size 1 to the form. The content is shared unchanged.
vector2ColumnVector :: Vector n e -> ColumnVector n e
vector2ColumnVector (Tensor shape xs) = Tensor (shape ++ [1]) xs


-- | Reinterpret a column vector as a vector by dropping the last entry of
-- the form. The content is shared unchanged.
columnVector2Vector :: ColumnVector n e -> Vector n e
columnVector2Vector (Tensor shape xs) = Tensor (init shape) xs


-- | A matrix whose first index type is @One@.
type RowVector n = Matrix One n


-- | Reinterpret a vector as a row vector by prepending a dimension of
-- size 1 to the form. The content is shared unchanged.
vector2RowVector :: Vector n e -> RowVector n e
vector2RowVector (Tensor shape xs) = Tensor (1 : shape) xs


-- | Reinterpret a row vector as a vector by dropping the first entry of
-- the form. The content is shared unchanged.
rowVector2Vector :: RowVector n e -> Vector n e
rowVector2Vector (Tensor shape xs) = Tensor (tail shape) xs


-- | Tensors are rendered as nested bracketed, comma separated lists; a
-- tensor with an empty form is rendered as its single element. For
-- example, form @[2,3]@ with content @a..f@ gives @[[a,b,c],[d,e,f]]@.
--
-- The precedence argument is ignored.
--
-- NOTE(review): with a non-empty form and empty content the final element
-- lookup indexes position @-1@; presumably such tensors cannot be built
-- through the typed interface -- confirm.
instance Show e => Show (Tensor i e) where
    showsPrec _ (Tensor [] v) = shows (v V.! 0)
    showsPrec _ (Tensor ds v) =
        render (Prelude.replicate depth 1) 1
            . shows (v V.! (total - 1))
            . (Prelude.replicate depth ']' ++)
      where
        -- Dimension sizes, innermost first.
        revForm = Prelude.reverse ds
        total = V.length v
        depth = Prelude.length ds

        -- The output is assembled back to front: the n-th call emits the
        -- element n places before the last one, followed by the brackets
        -- that separate it from its successor. The counter records the
        -- position (innermost first, starting from 1) of the successor.
        render counter n
            | n < total =
                render counter' (n + 1)
                    . shows (v V.! (total - n - 1))
                    . (Prelude.replicate closed ']' ++)
                    . (',' :)
                    . (Prelude.replicate closed '[' ++)
            | otherwise = (Prelude.replicate closed '[' ++)
          where
            (counter', closed) = advance revForm counter

        -- Advance the counter by one step with carry. Every leading
        -- component that has reached its dimension size is reset to 1 and
        -- counted; the first component that has not is incremented.
        advance sizes counter = step sizes counter [] 0
          where
            step (s:ss) (c:cs) done k
                | s == c = step ss cs (done ++ [1]) (k + 1)
                | otherwise = (done ++ ((c + 1) : cs), k)
            step _ _ done k = (done, k)


-- | Maps over the content, leaving the form untouched.
instance Functor (Tensor i) where
    fmap f (Tensor shape xs) = Tensor shape (V.map f xs)


-- | Zips the contents element by element; the form is taken from the
-- first argument and the form of the second is ignored.
instance Zip (Tensor i) where
    zipWith f (Tensor shape xs) (Tensor _ ys) =
        Tensor shape (V.zipWith f xs ys)


-- | Containers that can be built from a boxed vector of elements.
class FromVector t where
    fromVector :: V.Vector e -> t e


-- | Builds a tensor from the leading elements of the vector. The number of
-- elements used is the cardinality of the largest index; calls 'error' if
-- the vector is shorter than that.
instance (Bounded i, Cardinality i, MultiIndex i) =>
    FromVector (Tensor i) where
    fromVector xs = build maxBound xs
      where
        build :: (Cardinality i, MultiIndex i) =>
                 i -> V.Vector e -> Tensor i e
        build bound v
            | V.length v < needed =
                error ("fromVector: length of vector must be at least "
                       ++ show needed)
            | otherwise = Tensor (dimensions bound) (V.take needed v)
          where
            needed = card bound


-- | Builds a tensor from a list by going through 'fromVector'.
instance (Bounded i, Cardinality i, MultiIndex i) =>
    FromList (Tensor i) where
    fromList xs = fromVector (V.fromList xs)


-- | The dimensions of a tensor are given by the largest value of its index
-- type; elements are addressed through the linearised multi-index.
instance (Bounded i, MultiIndex i) => T.Tensor (Tensor i e) where
    type Index (Tensor i e) = i
    type Elem (Tensor i e) = e

    dims _ = maxBound

    (Tensor _ xs) ! j = xs V.! multiIndex2Linear j

    -- The form is recovered from the type of the result itself, which is
    -- why the result is referred to (but never evaluated) via 'asTypeOf'.
    generate f = result
      where
        result = Tensor shape
            (V.generate (product shape) (f . toMultiIndex . unlinearize shape))
        shape = fromMultiIndex (dims (asTypeOf undefined result))

    replicate e = result
      where
        result = Tensor shape (V.replicate (product shape) e)
        shape = fromMultiIndex (dims (asTypeOf undefined result))


-- | Direct sum of two tensors along the dimension selected by @n@. The
-- result has the form of the first argument, except that the size of the
-- selected dimension is the sum of the sizes in both arguments.
instance (Cardinal n, MultiIndex i, MultiIndex j, MultiIndexConcat n i j) =>
    DirectSum n (Tensor i e) (Tensor j e) where
    type SumSpace n (Tensor i e) (Tensor j e) = Tensor (Concat n i j) e

    directSum n (Tensor dsL xs) (Tensor dsR ys) =
        Tensor outShape (V.generate (product outShape) pick)
      where
        axis = fromCardinal n
        -- Sizes from the selected dimension onwards, for the left
        -- operand, the right operand and the result respectively.
        tailL = drop axis dsL
        tailR = drop axis dsR
        tailOut = ((dsL !! axis) + (dsR !! axis)) : drop (axis + 1) dsL
        outShape = take axis dsL ++ tailOut
        blockL = product tailL
        blockR = product tailR
        blockOut = product tailOut
        -- Each output block starts with a block of the left operand and
        -- continues with the corresponding block of the right operand.
        pick k
            | offset < blockL = xs V.! (block * blockL + offset)
            | otherwise = ys V.! (block * blockR + offset - blockL)
          where
            (block, offset) = quotRem k blockOut


-- | Matrix transposition: the two dimension sizes are swapped and the
-- content is rearranged accordingly.
instance (Ordinal i, Ordinal j) =>
    Transpose (Tensor (i :|: (j :|: Nil)) e) where
    type TransposeSpace (Tensor (i :|: (j :|: Nil)) e) =
        Tensor (j :|: (i :|: Nil)) e

    transpose (Tensor ds xs) =
        Tensor [cols, rows] (V.generate (rows * cols) fetch)
      where
        rows = head ds
        cols = head (tail ds)
        fetch n = xs V.! (r * cols + q)
          where
            (q, r) = quotRem n rows


-- | Element at the given multi-index, located with @linearize@. No bounds
-- or form checks are made here.
unsafeTensorGet :: [Int] -> Tensor i e -> e
unsafeTensorGet ix (Tensor shape xs) = xs V.! linearize shape ix


-- | Build a tensor with the given form, computing each element from the
-- multi-index obtained with @unlinearize@. The form is not checked against
-- the index type @i@.
unsafeTensorGen :: [Int] -> ([Int] -> e) -> Tensor i e
unsafeTensorGen shape f =
    Tensor shape (V.generate (product shape) (f . unlinearize shape))


-- | Element of a vector at the given position in the underlying content.
--
-- NOTE(review): the position is passed straight to @V.!@, which is
-- zero-based, whereas the row and column arguments of the matrix
-- operations in this module are one-based -- confirm this is intended.
unsafeVectorGet :: Int -> Vector i e -> e
unsafeVectorGet pos (Tensor _ xs) = xs V.! pos


-- | Build a vector of the given length from a function of the position in
-- the underlying content. The length is not checked against the index
-- type @i@.
unsafeVectorGen :: Int -> (Int -> e) -> Vector i e
unsafeVectorGen len f = Tensor [len] (V.generate len f)


-- | Element of a matrix at the given row and column, located with
-- @linearize@.
unsafeMatrixGet :: Int -> Int -> Matrix i j e -> e
unsafeMatrixGet r c (Tensor shape xs) = xs V.! linearize shape [r, c]


-- | Build a matrix with the given numbers of rows and columns from a
-- function of the row and column obtained with @unlinearize@.
unsafeMatrixGen :: Int -> Int -> (Int -> Int -> e) -> Matrix i j e
unsafeMatrixGen rows cols f = Tensor shape (V.generate (rows * cols) at)
  where
    shape = [rows, cols]
    at n = f r c
      where
        [r, c] = unlinearize shape n


-- | Extract a row (numbered from 1) of a matrix as a vector. A row is the
-- contiguous slice of the content whose length is the last dimension size.
unsafeMatrixGetRow :: Int -> Matrix i j e -> Vector j e
unsafeMatrixGetRow r (Tensor shape xs) =
    Tensor (tail shape) (V.slice ((r - 1) * width) width xs)
  where
    width = last shape


-- | Exchange two rows (numbered from 1) of a matrix.
unsafeMatrixRowSwitch :: Int -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixRowSwitch i1 i2 (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) pick)
  where
    rows = head shape
    cols = last shape
    -- Distance in the content between equal columns of the two rows.
    shift = cols * (i2 - i1)
    pick n
        | row == i1 - 1 = xs V.! (n + shift)
        | row == i2 - 1 = xs V.! (n - shift)
        | otherwise = xs V.! n
      where
        row = quot n cols


-- | Exchange two columns (numbered from 1) of a matrix.
unsafeMatrixColSwitch :: Int -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixColSwitch j1 j2 (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) pick)
  where
    rows = head shape
    cols = last shape
    -- Distance in the content between the two columns within one row.
    shift = j2 - j1
    pick n
        | col == j1 - 1 = xs V.! (n + shift)
        | col == j2 - 1 = xs V.! (n - shift)
        | otherwise = xs V.! n
      where
        col = rem n cols


-- | Multiply every element of a row (numbered from 1) by a factor, on the
-- right.
unsafeMatrixRowMult :: (Num e) => Int -> e -> Matrix i j e -> Matrix i j e
unsafeMatrixRowMult i a (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) scale)
  where
    rows = head shape
    cols = last shape
    scale n
        | quot n cols == i - 1 = (xs V.! n) * a
        | otherwise = xs V.! n


-- | Multiply every element of a column (numbered from 1) by a factor, on
-- the right.
unsafeMatrixColMult :: (Num e) => Int -> e -> Matrix i j e -> Matrix i j e
unsafeMatrixColMult j a (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) scale)
  where
    rows = head shape
    cols = last shape
    scale n
        | rem n cols == j - 1 = (xs V.! n) * a
        | otherwise = xs V.! n


-- | Divide every element of a row (numbered from 1) by a divisor.
unsafeMatrixRowDiv :: (Fractional e) =>
                      Int -> e -> Matrix i j e -> Matrix i j e
unsafeMatrixRowDiv i a (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) scale)
  where
    rows = head shape
    cols = last shape
    scale n
        | quot n cols == i - 1 = (xs V.! n) / a
        | otherwise = xs V.! n


-- | Divide every element of a column (numbered from 1) by a divisor.
unsafeMatrixColDiv :: (Fractional e) =>
                      Int -> e -> Matrix i j e -> Matrix i j e
unsafeMatrixColDiv j a (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) scale)
  where
    rows = head shape
    cols = last shape
    scale n
        | rem n cols == j - 1 = (xs V.! n) / a
        | otherwise = xs V.! n


-- | @unsafeMatrixRowAdd i1 a i2@ adds to row @i1@ the row @i2@ multiplied
-- on the right by @a@. Rows are numbered from 1.
unsafeMatrixRowAdd :: (Num e) =>
                      Int -> e -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixRowAdd i1 a i2 (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) combine)
  where
    rows = head shape
    cols = last shape
    shift = cols * (i2 - i1)
    combine n =
        if quot n cols == i1 - 1
            then xs V.! n + (xs V.! (n + shift))*a
            else xs V.! n


-- | @unsafeMatrixColAdd j1 a j2@ adds to column @j1@ the column @j2@
-- multiplied on the right by @a@. Columns are numbered from 1.
unsafeMatrixColAdd :: (Num e) =>
                      Int -> e -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixColAdd j1 a j2 (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) combine)
  where
    rows = head shape
    cols = last shape
    shift = j2 - j1
    combine n =
        if rem n cols == j1 - 1
            then xs V.! n + (xs V.! (n + shift))*a
            else xs V.! n


-- | @unsafeMatrixRowSub i1 a i2@ subtracts from row @i1@ the row @i2@
-- multiplied on the right by @a@. Rows are numbered from 1.
unsafeMatrixRowSub :: (Num e) =>
                      Int -> e -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixRowSub i1 a i2 (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) combine)
  where
    rows = head shape
    cols = last shape
    shift = cols * (i2 - i1)
    combine n =
        if quot n cols == i1 - 1
            then xs V.! n - (xs V.! (n + shift))*a
            else xs V.! n


-- | @unsafeMatrixColSub j1 a j2@ subtracts from column @j1@ the column
-- @j2@ multiplied on the right by @a@. Columns are numbered from 1.
unsafeMatrixColSub :: (Num e) =>
                      Int -> e -> Int -> Matrix i j e -> Matrix i j e
unsafeMatrixColSub j1 a j2 (Tensor shape xs) =
    Tensor shape (V.generate (rows * cols) combine)
  where
    rows = head shape
    cols = last shape
    shift = j2 - j1
    combine n =
        if rem n cols == j1 - 1
            then xs V.! n - (xs V.! (n + shift))*a
            else xs V.! n