module Data.Shape where
import Data.Monoid
import Data.Int (Int32)
import qualified Data.Vector.Unboxed as VU
import GHC.TypeLits (Nat)
import qualified Data.Dim as Dim
-- | Type-level index of the empty (rank-0) shape.
data Z

-- | Type-level index: shape @sh@ extended with one dense dimension whose
-- extent has type @e@ (always 'Int32' for the 'Sh' constructors here).
data (:#) sh e

-- | Type-level index: shape @sh@ extended with one sparse dimension whose
-- extent has type @e@ (always 'Int32' for the 'Sh' constructors here).
data (:.) sh e
-- | Value-level shape descriptor, indexed by its type-level shape so that
-- the type records the dense/sparse kind of every dimension.
data Sh s where
  -- | The empty shape.
  Z :: Sh Z
  -- | Extend a shape with a dense dimension descriptor.
  D :: Sh s -> Dim.Dd Int32 -> Sh (s :# Int32)
  -- | Extend a shape with a sparse dimension descriptor.
  S :: Sh s -> Dim.Sd Int32 -> Sh (s :. Int32)
-- | One dense dimension.
type D1 = Z :# Int32
-- | One sparse dimension.
type S1 = Z :. Int32
-- | Two dense dimensions.
type D2 = D1 :# Int32
-- | A dense dimension followed by a sparse one (built by 'mkCSR').
type CSR = D1 :. Int32
-- | Two sparse dimensions (built by 'mkCOO').
type COO = S1 :. Int32
-- | Render a shape as its per-dimension components, outermost dimension
-- first, separated by single spaces.
--
-- * A dense dimension renders as its extent.
-- * A sparse dimension renders as @(length of its index vector, extent)@.
-- * The empty shape 'Z' renders as the empty string.
--
-- Components are joined with 'unwords', so there is never a trailing space
-- and sparse components are separated from their neighbours like dense ones.
instance Show (Sh sh) where
  show = unwords . components
    where
      -- Local signature required: matching on a GADT in a local binding.
      components :: Sh s -> [String]
      components Z = []
      components (D inner (Dim.Dd m)) = show m : components inner
      components (S inner (Dim.Sd _ ix n)) =
        show (VU.length ix, n) : components inner
-- | Structural equality of two shapes with the same type-level index.
-- The outermost dimension descriptor is compared first (using the 'Eq'
-- instances of 'Dim.Dd' / 'Dim.Sd'), then the remaining inner shape.
instance Eq (Sh sh) where
  Z == Z = True
  D lhsInner lhsDim == D rhsInner rhsDim =
    lhsDim == rhsDim && lhsInner == rhsInner
  S lhsInner lhsDim == S rhsInner rhsDim =
    lhsDim == rhsDim && lhsInner == rhsInner
-- | Number of dimensions in a shape. 'Z' has rank 0; dense and sparse
-- dimensions each add one.
rank :: Sh sh -> Int
rank shape = case shape of
  Z         -> 0
  D inner _ -> rank inner + 1
  S inner _ -> rank inner + 1
-- | Extent of every dimension, outermost dimension first, converted from
-- 'Int32' to 'Int'. For a sparse dimension this is the extent stored in its
-- last 'Dim.Sd' field, not the number of stored indices.
dim :: Sh sh -> [Int]
dim shape = case shape of
  Z                           -> []
  D inner (Dim.Dd extent)     -> fromIntegral extent : dim inner
  S inner (Dim.Sd _ _ extent) -> fromIntegral extent : dim inner
-- | A rank-1 dense shape with the given extent.
mkD1 :: Int32 -> Sh D1
mkD1 = D Z . Dim.Dd
-- | A rank-1 sparse shape.
--
-- Arguments, in order: the extent; a vector stored as @Just@ in the first
-- 'Dim.Sd' field (presumably segment/offset data — confirm against
-- "Data.Dim"); the index vector stored in the second 'Dim.Sd' field.
mkS1 :: Int32 -> VU.Vector Int32 -> VU.Vector Int32 -> Sh S1
mkS1 extent segments indices = S Z (Dim.Sd (Just segments) indices extent)
-- | A rank-2 dense shape. The first argument is the extent of the inner
-- (first-built) dimension, the second that of the outer dimension.
mkD2 :: Int32 -> Int32 -> Sh D2
mkD2 m n = D firstDim (Dim.Dd n)
  where
    firstDim = D Z (Dim.Dd m)
-- | A dense dimension of extent @m@ followed by a sparse dimension of
-- extent @n@.
--
-- @icml@ is stored as @Just@ in the first 'Dim.Sd' field and @iidx@ as the
-- index vector (NOTE(review): presumably cumulative per-row offsets and
-- column indices respectively, as in CSR storage — confirm).
mkCSR :: Int32 -> Int32 -> VU.Vector Int32 -> VU.Vector Int32 -> Sh CSR
mkCSR m n icml iidx = S rows (Dim.Sd (Just icml) iidx n)
  where
    rows = D Z (Dim.Dd m)
-- | Two sparse dimensions with extents @m@ and @n@, indexed by @vi@ and
-- @vj@ respectively. Neither dimension carries the optional first
-- 'Dim.Sd' field (it is 'Nothing').
--
-- NOTE(review): coordinate format presumably requires @vi@ and @vj@ to
-- have equal length; this is not checked here.
mkCOO :: Int32 -> Int32 -> VU.Vector Int32 -> VU.Vector Int32 -> Sh COO
mkCOO m n vi vj = S (S Z (coord vi m)) (coord vj n)
  where
    -- Sparse descriptor that has only an index vector and an extent.
    coord ixs extent = Dim.Sd Nothing ixs extent