module Data.Tensor.Static (
IsTensor(..)
, Tensor(..)
, TensorConstructor
, PositiveDims
, fill
, zero
, enumFromN
, EnumFromN
, enumFromStepN
, EnumFromStepN
, generate
, Generate
, dimensions
, elemsNumber
, subtensorsElemsNumbers
, ElemsNumber
, SubtensorsElemsNumbers
, FlattenIndex
, AllIndexes
, NatsFromTo
, NormalizeDims
, withTensor
, add
, Add
, diff
, Diff
, scale
, Scale
, cons
, Cons
, ConsSubtensorDims
, DimsAfterCons
, snoc
, Snoc
, SnocSubtensorDims
, DimsAfterSnoc
, append
, Append
, DimsAfterAppend
, remove
, Remove
, DimsAfterRemove
, NestedList
, toNestedList
, ToNestedList
, tensorElem
, TensorElem
, Subtensor
, SubtensorStartIndex
, SubtensorDims
, subtensor
, SubtensorCtx
, getSubtensor
, GetSubtensor
, getSubtensorElems
, GetSubtensorElems
, setSubtensor
, SetSubtensor
, setSubtensorElems
, SetSubtensorElems
, mapSubtensorElems
, MapSubtensorElems
, SliceEndIndex
, ElemsInSlice
, slice
, Slice
, getSlice
, GetSlice
, getSliceElems
, GetSliceElems
, setSlice
, SetSlice
, setSliceElems
, SetSliceElems
, mapSliceElems
, MapSliceElems
, MonoFunctorCtx
, MonoFoldableCtx
, MonoTraversableCtx
, MonoZipCtx
, unsafeWithTensorPtr
) where
import Data.Kind (Type)
import Data.List (intersperse)
import Data.Proxy (Proxy(..))
import Data.Type.Bool (If, type (&&))
import Data.Type.Equality (type (==))
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable(..))
import GHC.TypeLits (Nat, KnownNat, natVal, type (+), type (-), type (<=?), type (*), TypeError, ErrorMessage(..))

import Control.Lens (Lens', lens, Each(..), traversed)
import Data.Containers (MonoZip(..))
import Data.Function.NAry (NAry, ApplyNAry(..))
import Data.List.Split (chunksOf)
import qualified Data.List.Unrolled as U
import Data.MonoTraversable (MonoFunctor(..), MonoFoldable(..), MonoTraversable(..), Element)
import Data.Singletons.Prelude.List (Tail, Product, Length)
import Type.List (MkCtx, DemoteWith(..), KnownNats(..))
-- | Demote a type-level natural, supplied by type application, to an 'Int'.
--
-- For example @natVal' \@3@ evaluates to @3@.
natVal' :: forall (n :: Nat). (KnownNat n) => Int
natVal' =
  let asInteger = natVal (Proxy :: Proxy n)
  in fromInteger asInteger
-- | @'True@ when every dimension of the shape is at least 1.
--
-- The empty shape (a scalar) is trivially positive.
type family PositiveDims (dims :: [Nat]) :: Bool where
  PositiveDims '[] = 'True
  PositiveDims (dim ': rest) = (1 <=? dim) && PositiveDims rest
-- | Position of the element at @index@ in the flat element list of a tensor
-- of shape @dims@.
--
-- Each coordinate is multiplied by the element count of the subtensor one
-- level below it (see 'SubtensorsElemsNumbers'), and the products are summed.
type FlattenIndex (index :: [Nat]) (dims :: [Nat]) =
  FlattenIndex' index (SubtensorsElemsNumbers dims)

-- | Worker for 'FlattenIndex': the dot product of the index with the
-- per-dimension strides. Reports a type error if the lengths differ.
type family FlattenIndex' (index :: [Nat]) (elemNumbers :: [Nat]) :: Nat where
  FlattenIndex' '[] '[] = 0
  FlattenIndex' (idx ': idxs) '[] =
    TypeError ('Text "FlattenIndex: Too many dimensions in the index for subtensor.")
  FlattenIndex' '[] (stride ': strides) =
    TypeError ('Text "FlattenIndex: Not enough dimensions in the index for subtensor.")
  FlattenIndex' (idx ': idxs) (stride ': strides) =
    idx * stride + FlattenIndex' idxs strides
-- | Remove every dimension equal to 1 from a shape.
--
-- For example @NormalizeDims '[1, 3, 1, 2]@ is @'[3, 2]@. Slices whose
-- shape contains 1s are exposed as tensors of the normalized shape.
type family NormalizeDims (dims :: [Nat]) :: [Nat] where
  NormalizeDims '[] = '[]
  NormalizeDims (1 ': rest) = NormalizeDims rest
  NormalizeDims (dim ': rest) = dim ': NormalizeDims rest
-- | Every index of a tensor of shape @dims@, in the same order as the
-- tensor's flat element list: the first coordinate varies slowest and the
-- last one fastest.
type AllIndexes (dims :: [Nat]) = Sequence (IndexesRanges dims)

-- | Cartesian product of a list of lists, the type-level analogue of
-- 'sequence' on lists.
type family Sequence (xss :: [[k]]) :: [[k]] where
  Sequence '[] = '[ '[] ]
  Sequence (choices ': rest) = Sequence' choices (Sequence rest)

-- | Prefix every list in @yss@ with each element of @xs@ in turn.
type family Sequence' (xs :: [k]) (yss :: [[k]]) :: [[k]] where
  Sequence' '[] _ = '[]
  Sequence' (x ': xs) yss = Sequence'' x yss ++ Sequence' xs yss

-- | Prefix every list in @yss@ with @x@.
type family Sequence'' (x :: k) (yss :: [[k]]) :: [[k]] where
  Sequence'' _ '[] = '[]
  Sequence'' x (ys ': yss) = (x ': ys) ': Sequence'' x yss

infixr 5 ++

-- | Type-level list concatenation.
type family (++) (xs :: [k]) (ys :: [k]) :: [k] where
  '[] ++ ys = ys
  (x ': xs) ++ ys = x ': (xs ++ ys)
-- | For each dimension @d@ of the shape, the list of valid coordinates
-- @[0 .. d - 1]@.
--
-- Reports a custom type error if any dimension is 0.
type family IndexesRanges (dims :: [Nat]) :: [[Nat]] where
  IndexesRanges '[] = '[]
  IndexesRanges (d ': ds) = IndexesRanges' (d ': ds) (1 <=? d)

-- | Worker for 'IndexesRanges'. The flag records whether the head dimension
-- is positive, so that the error message can name the offending dimension.
type family IndexesRanges' (dims :: [Nat]) (dimPositive :: Bool) :: [[Nat]] where
  -- The subtraction was missing here (@d 1@ is a kind error).
  IndexesRanges' (d ': ds) 'True = NatsFromTo 0 (d - 1) ': IndexesRanges ds
  IndexesRanges' (d ': _) 'False =
    TypeError ('Text "IndexesRanges: Tensor has non-positive dimension: " ':<>: 'ShowType d)
-- | The inclusive range @[from .. to]@ of naturals; empty when @from > to@.
type NatsFromTo (from :: Nat) (to :: Nat) = NatsFromTo' from to (from <=? to)

-- | Worker for 'NatsFromTo'. The flag holds the result of @from <=? to@ so
-- that the recursion stops once the lower bound passes the upper one.
type family NatsFromTo' (from :: Nat) (to :: Nat) (fromLTEto :: Bool) :: [Nat] where
  NatsFromTo' _ _ 'False = '[]
  NatsFromTo' lo hi 'True = lo ': NatsFromTo' (lo + 1) hi (lo + 1 <=? hi)
-- | Tensors whose shape @dims@ and element type @e@ are known statically.
--
-- The superclasses require every dimension to be positive and the shape to
-- be demotable to a term-level list.
class (PositiveDims dims ~ 'True, KnownNats dims) => IsTensor (dims :: [Nat]) e where
  -- | Representation chosen by each instance.
  data Tensor dims e :: Type

  -- | Build a tensor from exactly @ElemsNumber dims@ curried elements.
  tensor :: TensorConstructor dims e

  -- | Build a tensor from a flat element list. The length of the list is not
  -- checked by the class; what a short list does is up to the instance.
  unsafeFromList :: [e] -> Tensor dims e

  -- | The tensor's elements as a flat list, in the order 'unsafeFromList'
  -- consumes them.
  toList :: Tensor dims e -> [e]

-- | Type of a curried function taking @ElemsNumber dims@ elements and
-- returning a tensor of shape @dims@.
type TensorConstructor (dims :: [Nat]) (e :: Type) =
  NAry (ElemsNumber dims) e (Tensor dims e)

-- | Total number of elements in a tensor of shape @dims@ (1 for a scalar).
type family ElemsNumber (dims :: [Nat]) :: Nat where
  ElemsNumber '[] = 1
  ElemsNumber (dim ': rest) = dim * ElemsNumber rest
-- | Rank-0 tensors: a single element wrapped in 'Scalar'.
instance IsTensor '[] e where
  newtype Tensor '[] e = Scalar e

  tensor = Scalar

  -- Only the first element is used; any further elements are ignored.
  unsafeFromList xs = case xs of
    a : _ -> Scalar a
    []    -> error "Not enough elements to build a Tensor of shape []."

  toList (Scalar a) = [a]
-- | Render a tensor as @Tensor'd1'd2... <nested list>@, for example
-- @Tensor'2'3 [[0,1,2],[3,4,5]]@. The shape is written with apostrophes,
-- mirroring the type-level list syntax.
instance
  ( Show (NestedList (Length dims) e)
  , IsTensor dims e
  , ToNestedListWrk dims e
  , KnownNats dims
  ) => Show (Tensor dims e) where
  show t = concat ["Tensor'", shape, " ", show (toNestedList t)]
    where
      shape = concat (intersperse "\'" (show <$> dimensions @dims))
-- | Scalars render as @Scalar <value>@ rather than through the general
-- nested-list instance.
--
-- This instance overlaps the general @Show (Tensor dims e)@ instance
-- (which also matches @dims ~ '[]@). The OVERLAPPING pragma makes GHC
-- select this more specific instance; without it, showing a scalar is
-- rejected as ambiguous.
instance {-# OVERLAPPING #-} (Show e) => Show (Tensor '[] e) where
  show (Scalar e) = "Scalar " ++ show e
-- | Feed the tensor's elements, in 'toList' order, to a curried function of
-- @ElemsNumber dims@ arguments and return its result.
withTensor :: forall dims e r.
  ( IsTensor dims e
  , ApplyNAry (ElemsNumber dims) e r
  )
  => Tensor dims e
  -> NAry (ElemsNumber dims) e r
  -> r
withTensor t f = applyNAry @(ElemsNumber dims) @e @r f elems
  where
    elems = toList t
-- | The shape @dims@ as a term-level list, e.g. @[2, 3]@ for @'[2, 3]@.
dimensions :: forall (dims :: [Nat]). (KnownNats dims) => [Int]
dimensions = natsVal @dims

-- | Number of elements in a tensor of shape @dims@.
elemsNumber :: forall (dims :: [Nat]). (KnownNat (ElemsNumber dims)) => Int
elemsNumber = fromInteger (natVal (Proxy @(ElemsNumber dims)))

-- | For each dimension, the number of elements in one subtensor below it:
-- the row-major strides of the shape, e.g. @[12, 4, 1]@ for @'[2, 3, 4]@.
subtensorsElemsNumbers :: forall (dims :: [Nat]). (KnownNats (SubtensorsElemsNumbers dims)) => [Int]
subtensorsElemsNumbers = natsVal @(SubtensorsElemsNumbers dims)
-- | Element-wise sum of two tensors of the same shape.
add :: (Add dims e) => Tensor dims e -> Tensor dims e -> Tensor dims e
add x y = ozipWith (+) x y

-- | Constraints needed by 'add'. The unrolled zip/unzip/map constraints come
-- from the 'MonoZip' instance that 'ozipWith' uses.
type Add (dims :: [Nat]) e =
  ( IsTensor dims e
  , Num e
  , U.ZipWith (ElemsNumber dims)
  , U.Zip (ElemsNumber dims)
  , U.Unzip (ElemsNumber dims)
  , U.Map (ElemsNumber dims)
  )
-- | Element-wise difference @t1 - t2@ of two tensors of the same shape.
--
-- The subtraction operator was missing (@ozipWith ()@ does not typecheck).
diff :: (Diff dims e) => Tensor dims e -> Tensor dims e -> Tensor dims e
diff = ozipWith (-)

-- | Constraints needed by 'diff'. The unrolled zip/unzip/map constraints
-- come from the 'MonoZip' instance that 'ozipWith' uses.
type Diff (dims :: [Nat]) e =
  ( IsTensor dims e
  , Num e
  , U.ZipWith (ElemsNumber dims)
  , U.Zip (ElemsNumber dims)
  , U.Unzip (ElemsNumber dims)
  , U.Map (ElemsNumber dims)
  )
-- | Multiply every element of the tensor by @k@ on the right.
scale :: (Scale dims e) => Tensor dims e -> e -> Tensor dims e
scale t k = omap (\x -> x * k) t

-- | Constraints needed by 'scale'.
type Scale (dims :: [Nat]) e =
  ( IsTensor dims e
  , Num e
  , U.Map (ElemsNumber dims)
  )
-- | A tensor of shape @dims@ whose elements are all @x@.
fill :: forall (dims :: [Nat]) e. (Fill dims e) => e -> Tensor dims e
fill x = unsafeFromList (U.replicate @(ElemsNumber dims) x)

-- | Constraints needed by 'fill'.
type Fill (dims :: [Nat]) e = (IsTensor dims e, U.Replicate (ElemsNumber dims))

-- | A tensor whose elements are all @0@.
zero :: (Fill dims e, Num e) => Tensor dims e
zero = fill 0
-- | A tensor whose elements, in 'toList' order, are the
-- @ElemsNumber dims@ values produced by 'U.enumFromN' from @start@.
-- Presumably @start, start + 1, ...@; see "Data.List.Unrolled" to confirm.
enumFromN :: forall (dims :: [Nat]) e.
  (EnumFromN dims e)
  => e
  -> Tensor dims e
enumFromN start = unsafeFromList (U.enumFromN @(ElemsNumber dims) start)

-- | Constraints needed by 'enumFromN'.
type EnumFromN (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.EnumFromN (ElemsNumber dims)
  , Num e
  )
-- | A tensor whose elements, in 'toList' order, are the
-- @ElemsNumber dims@ values produced by 'U.enumFromStepN' from @start@ and
-- @step@. Presumably @start, start + step, ...@; see "Data.List.Unrolled"
-- to confirm.
enumFromStepN :: forall (dims :: [Nat]) e.
  (EnumFromStepN dims e)
  => e
  -> e
  -> Tensor dims e
enumFromStepN start step =
  unsafeFromList (U.enumFromStepN @(ElemsNumber dims) start step)

-- | Constraints needed by 'enumFromStepN'.
type EnumFromStepN (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.EnumFromStepN (ElemsNumber dims)
  , Num e
  )
-- | Build a tensor by evaluating @f@ once per index in 'AllIndexes' order.
--
-- @f@ receives each index as a type-level list through a 'Proxy' and may
-- rely on the constraint built by @MkCtx@ from @ctx@ for that index.
generate :: forall (dims :: [Nat]) (e :: Type) (kctx :: Type) (ctx :: kctx).
  (Generate dims e kctx ctx)
  => (forall (index :: [Nat]).
        (MkCtx [Nat] kctx ctx index)
        => Proxy index
        -> e
     )
  -> Tensor dims e
generate f = unsafeFromList elems
  where
    elems = demoteWith @[Nat] @kctx @ctx @(AllIndexes dims) f

-- | Constraints needed by 'generate'.
type Generate (dims :: [Nat]) (e :: Type) (kctx :: Type) (ctx :: kctx) =
  ( IsTensor dims e
  , DemoteWith [Nat] kctx ctx (AllIndexes dims)
  )
-- | @depth@ levels of list nesting around @e@:
-- @NestedList 0 e = e@, @NestedList 2 e = [[e]]@.
type family NestedList (depth :: Nat) (e :: Type) :: Type where
  NestedList 0 e = e
  -- The subtraction was missing here (@n 1@ is a kind error).
  NestedList n e = [NestedList (n - 1) e]
-- | Convert a tensor into nested lists, one level per dimension.
toNestedList :: forall dims e. (ToNestedList dims e)
  => Tensor dims e
  -> NestedList (Length dims) e
toNestedList t = toNestedListWrk @dims @e (toList t)

-- | Constraints needed by 'toNestedList'.
type ToNestedList (dims :: [Nat]) e = (IsTensor dims e, ToNestedListWrk dims e)

-- | Worker that reshapes a flat element list into nested lists for @dims@.
class ToNestedListWrk (dims :: [Nat]) e where
  toNestedListWrk :: [e] -> NestedList (Length dims) e

-- | A scalar is its single element.
instance ToNestedListWrk '[] e where
  toNestedListWrk = head

-- | A single dimension is the flat list itself.
instance ToNestedListWrk '[x] e where
  toNestedListWrk = id

-- | Two or more dimensions: cut the flat list into chunks of one subtensor
-- each and reshape every chunk recursively.
instance ( ToNestedListWrk (xx ': xs) e
         , KnownNat (Product (xx ': xs))
         , NestedList (Length (x ': xx ': xs)) e ~ [NestedList (Length (xx ': xs)) e]
         ) => ToNestedListWrk (x ': xx ': xs) e where
  toNestedListWrk elems = map (toNestedListWrk @(xx ': xs)) (chunksOf chunkSize elems)
    where
      chunkSize = natVal' @(Product (xx ': xs))
-- | Type of the subtensor selected by fixing the leading coordinates in
-- @index@. Fixed dimensions collapse to 1 and are then removed.
type Subtensor index dims e = Tensor (NormalizeDims (SubtensorDims index dims)) e

-- | Element counts of the subtensors one level below each dimension, i.e.
-- row-major strides: @'[12, 4, 1]@ for @'[2, 3, 4]@.
type SubtensorsElemsNumbers (dims :: [Nat]) = Tail (SubtensorsElemsNumbers' dims)

-- | Suffix products of the shape, followed by 1:
-- @'[24, 12, 4, 1]@ for @'[2, 3, 4]@.
type family SubtensorsElemsNumbers' (dims :: [Nat]) :: [Nat] where
  SubtensorsElemsNumbers' '[] = '[1]
  SubtensorsElemsNumbers' (dim ': rest) =
    SubtensorsElemsNumbers'' dim (SubtensorsElemsNumbers' rest)

-- | Prepend @dim@ times the head of an already computed suffix-product list.
type family SubtensorsElemsNumbers'' (dim :: Nat) (dims :: [Nat]) :: [Nat] where
  SubtensorsElemsNumbers'' dim (count ': counts) = dim * count ': count ': counts
-- | Shape of the subtensor selected by fixing the leading coordinates in
-- @index@. Each fixed dimension becomes 1 and the remaining dimensions are
-- kept unchanged.
--
-- Reports a custom type error when the index is longer than the shape or a
-- coordinate lies outside @[0 .. d - 1]@.
type family SubtensorDims (index :: [Nat]) (dims :: [Nat]) :: [Nat] where
  SubtensorDims '[] ds = ds
  SubtensorDims (_ ': _ ) '[] =
    TypeError ('Text "SubtensorDims: Too many dimensions in the index for subtensor.")
  -- The subtraction in @d - 1@ was missing here.
  SubtensorDims (i ': is) (d ': ds) =
    If (i <=? d - 1)
      (1 ': SubtensorDims is ds)
      (TypeError
        ('Text "SubtensorDims: Index "
          ':<>: 'ShowType i
          ':<>: 'Text " is outside of the range of dimension [0.."
          ':<>: 'ShowType (d - 1)
          ':<>: 'Text "]."))
-- | Full start index of the subtensor selected by @index@: the given
-- coordinates, padded with 0 for the dimensions the index does not fix.
--
-- Reports a custom type error when the index is longer than the shape or a
-- coordinate lies outside @[0 .. d - 1]@.
type family SubtensorStartIndex (index :: [Nat]) (dims :: [Nat]) :: [Nat] where
  SubtensorStartIndex '[] '[] = '[]
  SubtensorStartIndex (i ': is) '[] =
    TypeError ('Text "SubtensorStartIndex: Too many dimensions in the index for subtensor.")
  SubtensorStartIndex '[] (d ': ds) = 0 ': SubtensorStartIndex '[] ds
  -- The subtraction in @d - 1@ was missing here.
  SubtensorStartIndex (i ': is) (d ': ds) =
    If (i <=? d - 1)
      (i ': SubtensorStartIndex is ds)
      (TypeError
        ('Text "SubtensorStartIndex: Index "
          ':<>: 'ShowType i
          ':<>: 'Text " is outside of the range of dimension [0.."
          ':<>: 'ShowType (d - 1)
          ':<>: 'Text "]."))
-- | Extract the subtensor obtained by fixing the leading coordinates given
-- in @index@. Implemented as the slice from 'SubtensorStartIndex' with shape
-- 'SubtensorDims'.
getSubtensor :: forall (index :: [Nat]) (dims :: [Nat]) e.
  (GetSubtensor index dims e)
  => Tensor dims e
  -> Subtensor index dims e
getSubtensor t =
  getSlice @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t

-- | Constraints needed by 'getSubtensor'.
type GetSubtensor index dims e =
  GetSlice (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e
-- | Elements of the subtensor selected by @index@, as a flat list in
-- 'toList' order.
getSubtensorElems :: forall (index :: [Nat]) (dims :: [Nat]) e.
  (GetSubtensorElems index dims e)
  => Tensor dims e
  -> [e]
getSubtensorElems t =
  getSliceElems @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t

-- | Constraints needed by 'getSubtensorElems'.
type GetSubtensorElems index dims e =
  GetSliceElems (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e
-- | Replace the subtensor selected by @index@ with the given one.
setSubtensor :: forall (index :: [Nat]) (dims :: [Nat]) e.
  (SetSubtensor index dims e)
  => Tensor dims e
  -> Subtensor index dims e
  -> Tensor dims e
setSubtensor t st =
  setSlice @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t st

-- | Constraints needed by 'setSubtensor'.
type SetSubtensor index dims e =
  SetSlice (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e
-- | Overwrite the elements of the subtensor selected by @index@ from a flat
-- list. Returns 'Nothing' if the list is shorter than the subtensor; extra
-- elements are ignored.
setSubtensorElems :: forall (index :: [Nat]) (dims :: [Nat]) e.
  (SetSubtensorElems index dims e)
  => Tensor dims e
  -> [e]
  -> Maybe (Tensor dims e)
setSubtensorElems t xs =
  setSliceElems @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t xs

-- | Constraints needed by 'setSubtensorElems'.
type SetSubtensorElems index dims e =
  SetSliceElems (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e
-- | Apply @f@ to every element of the subtensor selected by @index@,
-- leaving all other elements unchanged.
mapSubtensorElems :: forall (index :: [Nat]) (dims :: [Nat]) e.
  (MapSubtensorElems index dims e)
  => Tensor dims e
  -> (e -> e)
  -> Tensor dims e
mapSubtensorElems t f =
  mapSliceElems @(SubtensorStartIndex index dims) @(SubtensorDims index dims) @dims @e t f

-- | Constraints needed by 'mapSubtensorElems'.
type MapSubtensorElems index dims e =
  MapSliceElems (SubtensorStartIndex index dims) (SubtensorDims index dims) dims e
-- | Lens onto the subtensor selected by @index@, built from 'getSubtensor'
-- and 'setSubtensor'.
subtensor :: forall (index :: [Nat]) (dims :: [Nat]) e.
  (SubtensorCtx index dims e)
  => Lens' (Tensor dims e) (Subtensor index dims e)
subtensor = lens getter setter
  where
    getter = getSubtensor @index @dims @e
    setter = setSubtensor @index @dims @e

-- | Constraints needed by 'subtensor'.
type SubtensorCtx index dims e =
  ( GetSubtensor index dims e
  , SetSubtensor index dims e
  )
-- | Lens onto a single element. @index@ must fix every dimension, which
-- the equality constraint in 'TensorElem' enforces by requiring the
-- subtensor to be a scalar.
tensorElem :: forall (index :: [Nat]) (dims :: [Nat]) e.
  (TensorElem index dims e)
  => Lens' (Tensor dims e) e
tensorElem = subtensor @index @dims @e . lens getScalar setScalar
  where
    getScalar (Scalar a) = a
    setScalar _ b = Scalar b

-- | Constraints needed by 'tensorElem'.
type TensorElem index dims e =
  ( SubtensorCtx index dims e
  , NormalizeDims (SubtensorDims index dims) ~ '[]
  )
-- | Last index (inclusive) of a slice that starts at @startIndex@ and has
-- shape @sliceDims@ inside a tensor of shape @dims@.
--
-- All three lists must have the same length, every slice dimension must be
-- positive, and the slice must fit inside the tensor. Otherwise a custom
-- type error names the condition that failed.
type family SliceEndIndex (startIndex :: [Nat]) (sliceDims :: [Nat]) (dims :: [Nat]) :: [Nat] where
  SliceEndIndex '[] '[] '[] = '[]
  SliceEndIndex '[] '[] (d ': ds) = TypeError ('Text "SliceEndIndex: Slice and its starting index have not enough dimensions.")
  SliceEndIndex '[] (sd ': sds) '[] = TypeError ('Text "SliceEndIndex: Slice has too many dimensions.")
  SliceEndIndex '[] (sd ': sds) (d ': ds) = TypeError ('Text "SliceEndIndex: Starting index of the slice has not enough dimensions.")
  SliceEndIndex (si ': sis) '[] '[] = TypeError ('Text "SliceEndIndex: Starting index of the slice has too many dimensions.")
  SliceEndIndex (si ': sis) '[] (d ': ds) = TypeError ('Text "SliceEndIndex: Slice has not enough dimensions.")
  SliceEndIndex (si ': sis) (sd ': sds) '[] = TypeError ('Text "SliceEndIndex: Slice and its starting index have too many dimensions.")
  SliceEndIndex (si ': sis) (sd ': sds) (d ': ds) = SliceEndIndex' (si ': sis) (sd ': sds) (d ': ds) (1 <=? sd)

-- | Worker for 'SliceEndIndex': checks that the head slice dimension is
-- positive.
type family SliceEndIndex' (startIndex :: [Nat]) (sliceDims :: [Nat]) (dims :: [Nat]) (sliceDimPositive :: Bool) :: [Nat] where
  SliceEndIndex' (si ': sis) (sd ': sds) (d ': ds) 'True = SliceEndIndex'' (si ': sis) (sd ': sds) (d ': ds) (si + sd <=? d)
  SliceEndIndex' _ (sd ': _) _ 'False =
    TypeError ('Text "SliceEndIndex: Slice has non-positive dimension: " ':<>: 'ShowType sd)

-- | Worker for 'SliceEndIndex': checks that the head slice dimension ends
-- inside the tensor, then emits its inclusive end coordinate
-- @si + sd - 1@. The subtractions were missing in the original text.
type family SliceEndIndex'' (startIndex :: [Nat]) (sliceDims :: [Nat]) (dims :: [Nat]) (sliceDimInside :: Bool) :: [Nat] where
  SliceEndIndex'' (si ': sis) (sd ': sds) (d ': ds) 'True = si + sd - 1 ': SliceEndIndex sis sds ds
  SliceEndIndex'' (si ': sis) (sd ': sds) (d ': ds) 'False =
    TypeError
      ( 'Text "SliceEndIndex: Slice dimension is outside of the tensor. It starts at "
        ':<>: 'ShowType si
        ':<>: 'Text " and ends at "
        ':<>: 'ShowType (si + sd - 1)
        ':<>: 'Text " which is outside of the range of the tensor's dimension [0.."
        ':<>: 'ShowType (d - 1)
        ':<>: 'Text "]." )
-- | One flag per element of a tensor of shape @dims@, in 'AllIndexes'
-- order: @'True@ when the element's index lies inside the slice that starts
-- at @startIndex@ and has shape @sliceDims@.
type ElemsInSlice (startIndex :: [Nat]) (sliceDims :: [Nat]) (dims :: [Nat]) =
  ElemsInSlice' startIndex (SliceEndIndex startIndex sliceDims dims) (AllIndexes dims)

-- | Test every index of the list against the inclusive bounds @[lo, hi]@.
type family ElemsInSlice' (startIndex :: [Nat]) (endIndex :: [Nat]) (indexes :: [[Nat]]) :: [Bool] where
  ElemsInSlice' _ _ '[] = '[]
  ElemsInSlice' lo hi (ix ': ixs) = ElemsInSlice'' ix lo hi ': ElemsInSlice' lo hi ixs

-- | Whether every coordinate of an index lies within the matching inclusive
-- bounds.
type family ElemsInSlice'' (index :: [Nat]) (startIndex :: [Nat]) (endIndex :: [Nat]) :: Bool where
  ElemsInSlice'' (i ': is) (lo ': los) (hi ': his) = lo <=? i && i <=? hi && ElemsInSlice'' is los his
  ElemsInSlice'' '[] '[] '[] = 'True
-- | Lens onto the slice that starts at @startIndex@ and has shape
-- @sliceDims@. The focus is a tensor of shape @NormalizeDims sliceDims@.
slice :: forall startIndex sliceDims dims e.
  (Slice startIndex sliceDims dims e)
  => Lens' (Tensor dims e) (Tensor (NormalizeDims sliceDims) e)
slice = lens getter setter
  where
    getter = getSlice @startIndex @sliceDims @dims @e
    setter = setSlice @startIndex @sliceDims @dims @e

-- | Constraints needed by 'slice'.
type Slice startIndex sliceDims dims e =
  ( IsTensor dims e
  , IsTensor (NormalizeDims sliceDims) e
  , GetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  , SetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )
-- | Copy the slice that starts at @startIndex@ and has shape @sliceDims@
-- into a tensor of shape @NormalizeDims sliceDims@.
getSlice :: forall startIndex sliceDims dims e.
  (GetSlice startIndex sliceDims dims e)
  => Tensor dims e
  -> Tensor (NormalizeDims sliceDims) e
getSlice t = unsafeFromList (getSliceElems @startIndex @sliceDims @dims @e t)

-- | Constraints needed by 'getSlice'.
type GetSlice startIndex sliceDims dims e =
  ( IsTensor dims e
  , IsTensor (NormalizeDims sliceDims) e
  , GetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )
-- | Elements of the slice that starts at @startIndex@ and has shape
-- @sliceDims@, as a flat list in 'toList' order.
getSliceElems :: forall startIndex sliceDims dims e.
  (GetSliceElems startIndex sliceDims dims e)
  => Tensor dims e
  -> [e]
getSliceElems t = getSliceElemsWrk @(ElemsInSlice startIndex sliceDims dims) (toList t)

-- | Constraints needed by 'getSliceElems'.
type GetSliceElems startIndex sliceDims dims e =
  ( IsTensor dims e
  , GetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )
-- | Raised when a flat element list is shorter than its shape requires.
-- The type-level element count should make this unreachable.
impossible_notEnoughTensorElems :: a
impossible_notEnoughTensorElems = error message
  where
    message = "Impossible happend! Not enough elements in the tensor. Please report this bug."
-- | Filter a flat element list by a type-level mask: keep the elements whose
-- flag is @'True@, in order.
class GetSliceElemsWrk (elemsInSlice :: [Bool]) where
  getSliceElemsWrk :: [e] -> [e]

-- | Mask exhausted: nothing more to keep.
instance GetSliceElemsWrk '[] where
  getSliceElemsWrk = const []

-- | Keep the current element.
instance (GetSliceElemsWrk rest) => GetSliceElemsWrk ('True ': rest) where
  getSliceElemsWrk elems = case elems of
    []       -> impossible_notEnoughTensorElems
    x : more -> x : getSliceElemsWrk @rest more

-- | Drop the current element.
instance (GetSliceElemsWrk rest) => GetSliceElemsWrk ('False ': rest) where
  getSliceElemsWrk elems = case elems of
    []       -> impossible_notEnoughTensorElems
    _ : more -> getSliceElemsWrk @rest more
-- | Overwrite the slice that starts at @startIndex@ and has shape
-- @sliceDims@ with the elements of @st@.
--
-- @st@ has exactly as many elements as the slice, so 'setSliceElems' cannot
-- return 'Nothing' here.
setSlice :: forall startIndex sliceDims dims e.
  (SetSlice startIndex sliceDims dims e)
  => Tensor dims e
  -> Tensor (NormalizeDims sliceDims) e
  -> Tensor dims e
setSlice t st =
  maybe impossible_notEnoughTensorElems id
    (setSliceElems @startIndex @sliceDims @dims @e t (toList st))

-- | Constraints needed by 'setSlice'.
type SetSlice startIndex sliceDims dims e =
  ( IsTensor dims e
  , IsTensor (NormalizeDims sliceDims) e
  , SetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )
-- | Overwrite the slice that starts at @startIndex@ and has shape
-- @sliceDims@ with elements taken from @xs@ in order.
--
-- Returns 'Nothing' if @xs@ is shorter than the slice; surplus elements are
-- ignored.
setSliceElems :: forall startIndex sliceDims dims e.
  (SetSliceElems startIndex sliceDims dims e)
  => Tensor dims e
  -> [e]
  -> Maybe (Tensor dims e)
setSliceElems t xs =
  fmap unsafeFromList
    (setSliceElemsWrk @(ElemsInSlice startIndex sliceDims dims) (toList t) xs)

-- | Constraints needed by 'setSliceElems'.
type SetSliceElems startIndex sliceDims dims e =
  ( IsTensor dims e
  , SetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  )
-- | Merge two lists under a type-level mask. Positions flagged @'True@ take
-- the next replacement element; the others keep the original element.
-- Yields 'Nothing' when the replacements run out.
class SetSliceElemsWrk (elemsInSlice :: [Bool]) where
  setSliceElemsWrk :: [e] -> [e] -> Maybe [e]

-- | Mask exhausted: any remaining replacements are ignored.
instance SetSliceElemsWrk '[] where
  setSliceElemsWrk _ _ = Just []

-- | Replace the current element with the next replacement.
instance (SetSliceElemsWrk rest) => SetSliceElemsWrk ('True ': rest) where
  setSliceElemsWrk originals replacements = case (originals, replacements) of
    ([], _)            -> impossible_notEnoughTensorElems
    (_, [])            -> Nothing
    (_ : os, r : rs)   -> (r :) <$> setSliceElemsWrk @rest os rs

-- | Keep the current element; replacements are not consumed.
instance (SetSliceElemsWrk rest) => SetSliceElemsWrk ('False ': rest) where
  setSliceElemsWrk originals replacements = case originals of
    []     -> impossible_notEnoughTensorElems
    o : os -> (o :) <$> setSliceElemsWrk @rest os replacements
-- | Apply @f@ to every element of the slice that starts at @startIndex@ and
-- has shape @sliceDims@, leaving all other elements unchanged.
mapSliceElems :: forall startIndex sliceDims dims e.
  (MapSliceElems startIndex sliceDims dims e)
  => Tensor dims e
  -> (e -> e)
  -> Tensor dims e
mapSliceElems t f =
  maybe impossible_notEnoughTensorElems id
    (setSliceElems @startIndex @sliceDims @dims @e t mapped)
  where
    -- The slice's own elements, transformed; same count as the slice.
    mapped = U.map @(ElemsNumber sliceDims) f (getSliceElems @startIndex @sliceDims @dims @e t)

-- | Constraints needed by 'mapSliceElems'.
type MapSliceElems startIndex sliceDims dims e =
  ( IsTensor dims e
  , GetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  , SetSliceElemsWrk (ElemsInSlice startIndex sliceDims dims)
  , U.Map (ElemsNumber sliceDims)
  )
-- | Delete the hyperplane at position @indexOnAxis@ along @axis@; that
-- dimension shrinks by one.
remove :: forall (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) e.
  (Remove axis indexOnAxis dims e)
  => Tensor dims e
  -> Tensor (DimsAfterRemove axis indexOnAxis dims) e
remove t = unsafeFromList (removeWrk @mask (toList t))

-- | Constraints needed by 'remove'.
type Remove (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) e =
  ( IsTensor dims e
  , IsTensor (DimsAfterRemove axis indexOnAxis dims) e
  , RemoveWrk (ElemsInSlice (RemoveSliceStartIndex axis indexOnAxis dims) (RemoveSliceDims axis indexOnAxis dims) dims)
  )
-- | Shape after 'remove': the @axis@ dimension shrinks by one.
--
-- Reports a custom type error if @axis@ is past the last dimension or
-- @index@ lies outside @[0 .. d - 1]@. The subtractions below were missing
-- in the original text.
type family DimsAfterRemove (axis :: Nat) (index :: Nat) (dims :: [Nat]) :: [Nat] where
  DimsAfterRemove _ _ '[] = TypeError ('Text "DimsAfterRemove: axis must be in range [0..(number of dimensions in the tensor)].")
  DimsAfterRemove 0 i (d ': ds) =
    If (i <=? d - 1)
      (d - 1 ': ds)
      (TypeError (
        'Text "DimsAfterRemove: Index "
          ':<>: 'ShowType i
          ':<>: 'Text " is outside of the range of dimension [0.."
          ':<>: 'ShowType (d - 1)
          ':<>: 'Text "]."))
  DimsAfterRemove a i (d ': ds) = d ': DimsAfterRemove (a - 1) i ds
-- | Start index of the hyperplane removed by 'remove': @indexOnAxis@ on
-- @axis@ and 0 everywhere else.
type RemoveSliceStartIndex (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) =
  RemoveSliceStartIndex' axis indexOnAxis dims 0

-- | Worker for 'RemoveSliceStartIndex'; @pos@ is the position of the
-- current dimension.
type family RemoveSliceStartIndex' (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) (n :: Nat) :: [Nat] where
  RemoveSliceStartIndex' _ _ '[] _ = '[]
  RemoveSliceStartIndex' axis idx (d ': ds) pos =
    If (axis == pos) idx 0 ': RemoveSliceStartIndex' axis idx ds (pos + 1)
-- | Shape of the hyperplane removed by 'remove': 1 on @axis@, unchanged
-- elsewhere.
--
-- Reports a custom type error if @axis@ is past the last dimension or the
-- index lies outside @[0 .. d - 1]@. The subtractions below were missing in
-- the original text.
type family RemoveSliceDims (axis :: Nat) (indexOnAxis :: Nat) (dims :: [Nat]) :: [Nat] where
  RemoveSliceDims _ _ '[] = TypeError ('Text "RemoveSliceDims: axis must be in range [0..(number of dimensions in the tensor)].")
  RemoveSliceDims 0 i (d ': ds) =
    If (i <=? d - 1)
      (1 ': ds)
      (TypeError (
        'Text "RemoveSliceDims: Index "
          ':<>: 'ShowType i
          ':<>: 'Text " is outside of the range of dimension [0.."
          ':<>: 'ShowType (d - 1)
          ':<>: 'Text "]."))
  RemoveSliceDims a i (d ': ds) = d ': RemoveSliceDims (a - 1) i ds
-- | Filter a flat element list by a type-level mask: drop the elements whose
-- flag is @'True@ and keep the rest, in order. This is the complement of
-- 'GetSliceElemsWrk'.
class RemoveWrk (elemsInSlice :: [Bool]) where
  removeWrk :: [e] -> [e]

-- | Mask exhausted: nothing more to keep.
instance RemoveWrk '[] where
  removeWrk = const []

-- | Keep the current element.
instance (RemoveWrk rest) => RemoveWrk ('False ': rest) where
  removeWrk elems = case elems of
    []       -> impossible_notEnoughTensorElems
    x : more -> x : removeWrk @rest more

-- | Drop the current element.
instance (RemoveWrk rest) => RemoveWrk ('True ': rest) where
  removeWrk elems = case elems of
    []       -> impossible_notEnoughTensorElems
    _ : more -> removeWrk @rest more
-- | Prepend the subtensor @st@ to @t@ along @axis@; that dimension grows by
-- one.
--
-- @dims@ must not contain a dimension equal to 1
-- (@dims ~ NormalizeDims dims@).
cons :: forall (axis :: Nat) (dims :: [Nat]) e.
  (Cons axis dims e)
  => Tensor (NormalizeDims (ConsSubtensorDims axis dims)) e
  -> Tensor dims e
  -> Tensor (DimsAfterCons axis dims) e
cons st t =
  setSlice @(ConsSubtensorStartingIndex dims) @(ConsSubtensorDims axis dims) @(DimsAfterCons axis dims) @e withBody st
  where
    -- A tensor of the result shape. Both setSlice calls together overwrite
    -- every element, so the fill value only has to be some value of type e.
    canvas = fill @(DimsAfterCons axis dims) @e (head (toList t))
    -- t placed at offset 1 along the axis.
    withBody = setSlice @(ConsTensorStartingIndex axis dims) @dims @(DimsAfterCons axis dims) canvas t

-- | Constraints needed by 'cons'.
type Cons (axis :: Nat) (dims :: [Nat]) e =
  ( SetSlice (ConsSubtensorStartingIndex dims) (ConsSubtensorDims axis dims) (DimsAfterCons axis dims) e
  , SetSlice (ConsTensorStartingIndex axis dims) dims (DimsAfterCons axis dims) e
  , dims ~ NormalizeDims dims
  , Fill (DimsAfterCons axis dims) e
  )
-- | Where 'cons' writes the prepended subtensor: 0 in every dimension.
type family ConsSubtensorStartingIndex (dims :: [Nat]) :: [Nat] where
  ConsSubtensorStartingIndex '[] = '[]
  ConsSubtensorStartingIndex (_ ': ds) = 0 ': ConsSubtensorStartingIndex ds

-- | Where 'cons' writes the original tensor: 1 on @axis@, 0 elsewhere.
type ConsTensorStartingIndex (axis :: Nat) (dims :: [Nat]) = ConsTensorStartingIndex' axis dims 0

-- | Worker for 'ConsTensorStartingIndex'; @i@ is the current dimension's
-- position.
type family ConsTensorStartingIndex' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  ConsTensorStartingIndex' _ '[] _ = '[]
  ConsTensorStartingIndex' a (d ': ds) i =
    If (a == i) 1 0 ': ConsTensorStartingIndex' a ds (i + 1)

-- | Shape of the subtensor prepended by 'cons': @dims@ with @axis@ set to 1.
type ConsSubtensorDims (axis :: Nat) (dims :: [Nat]) = ConsSubtensorDims' axis dims 0

-- | Worker for 'ConsSubtensorDims'; @i@ is the current dimension's position.
type family ConsSubtensorDims' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  ConsSubtensorDims' _ '[] _ = '[]
  ConsSubtensorDims' a (d ': ds) i =
    If (a == i) 1 d ': ConsSubtensorDims' a ds (i + 1)

-- | Shape after 'cons': the @axis@ dimension grows by one. Reports a type
-- error when @axis@ is past the last dimension. The subtraction in
-- @a - 1@ was missing in the original text.
type family DimsAfterCons (axis :: Nat) (dims :: [Nat]) :: [Nat] where
  DimsAfterCons 0 (d ': ds) = d + 1 ': ds
  DimsAfterCons a (d ': ds) = d ': DimsAfterCons (a - 1) ds
  DimsAfterCons _ '[] = TypeError ('Text "DimsAfterCons: axis must be in range [0..(number of dimensions in the tensor)].")
-- | Append the subtensor @st@ to the end of @t@ along @axis@; that
-- dimension grows by one.
--
-- @dims@ must not contain a dimension equal to 1
-- (@dims ~ NormalizeDims dims@).
snoc :: forall (axis :: Nat) (dims :: [Nat]) e.
  (Snoc axis dims e)
  => Tensor dims e
  -> Tensor (NormalizeDims (SnocSubtensorDims axis dims)) e
  -> Tensor (DimsAfterSnoc axis dims) e
snoc t st =
  setSlice @(SnocSubtensorStartingIndex axis dims) @(SnocSubtensorDims axis dims) @(DimsAfterSnoc axis dims) @e withBody st
  where
    -- A tensor of the result shape. Both setSlice calls together overwrite
    -- every element, so the fill value only has to be some value of type e.
    canvas = fill @(DimsAfterSnoc axis dims) @e (head (toList t))
    -- t placed at the origin.
    withBody = setSlice @(SnocTensorStartingIndex dims) @dims @(DimsAfterSnoc axis dims) canvas t

-- | Constraints needed by 'snoc'.
type Snoc (axis :: Nat) (dims :: [Nat]) e =
  ( SetSlice (SnocSubtensorStartingIndex axis dims) (SnocSubtensorDims axis dims) (DimsAfterSnoc axis dims) e
  , SetSlice (SnocTensorStartingIndex dims) dims (DimsAfterSnoc axis dims) e
  , dims ~ NormalizeDims dims
  , Fill (DimsAfterSnoc axis dims) e
  )
-- | Where 'snoc' writes the original tensor: 0 in every dimension.
type family SnocTensorStartingIndex (dims :: [Nat]) :: [Nat] where
  SnocTensorStartingIndex '[] = '[]
  SnocTensorStartingIndex (_ ': ds) = 0 ': SnocTensorStartingIndex ds

-- | Where 'snoc' writes the appended subtensor: the old size @d@ on @axis@
-- (one past the last original position), 0 elsewhere.
type SnocSubtensorStartingIndex (axis :: Nat) (dims :: [Nat]) = SnocSubtensorStartingIndex' axis dims 0

-- | Worker for 'SnocSubtensorStartingIndex'; @i@ is the current dimension's
-- position.
type family SnocSubtensorStartingIndex' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  SnocSubtensorStartingIndex' _ '[] _ = '[]
  SnocSubtensorStartingIndex' a (d ': ds) i =
    If (a == i) d 0 ': SnocSubtensorStartingIndex' a ds (i + 1)

-- | Shape of the subtensor appended by 'snoc': @dims@ with @axis@ set to 1.
type SnocSubtensorDims (axis :: Nat) (dims :: [Nat]) = SnocSubtensorDims' axis dims 0

-- | Worker for 'SnocSubtensorDims'; @i@ is the current dimension's position.
type family SnocSubtensorDims' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  SnocSubtensorDims' _ '[] _ = '[]
  SnocSubtensorDims' a (d ': ds) i =
    If (a == i) 1 d ': SnocSubtensorDims' a ds (i + 1)

-- | Shape after 'snoc': the @axis@ dimension grows by one. Reports a type
-- error when @axis@ is past the last dimension. The subtraction in
-- @a - 1@ was missing in the original text.
type family DimsAfterSnoc (axis :: Nat) (dims :: [Nat]) :: [Nat] where
  DimsAfterSnoc 0 (d ': ds) = d + 1 ': ds
  DimsAfterSnoc a (d ': ds) = d ': DimsAfterSnoc (a - 1) ds
  DimsAfterSnoc _ '[] = TypeError ('Text "DimsAfterSnoc: axis must be in range [0..(number of dimensions in the tensor)].")
-- | Concatenate two tensors along @axis@. All other dimensions must match,
-- and neither shape may contain a dimension equal to 1.
append :: forall (axis :: Nat) (dims0 :: [Nat]) (dims1 :: [Nat]) e.
  (Append axis dims0 dims1 e)
  => Tensor dims0 e
  -> Tensor dims1 e
  -> Tensor (DimsAfterAppend axis dims0 dims1) e
append t0 t1 =
  setSlice @(AppendSndTensorStartingIndex axis dims1) @dims1 @(DimsAfterAppend axis dims0 dims1) @e withFirst t1
  where
    -- A tensor of the result shape. Both setSlice calls together overwrite
    -- every element, so the fill value only has to be some value of type e.
    canvas = fill @(DimsAfterAppend axis dims0 dims1) @e (head (toList t0))
    -- t0 placed at the origin.
    withFirst = setSlice @(AppendFstTensorStartingIndex dims0) @dims0 @(DimsAfterAppend axis dims0 dims1) canvas t0

-- | Constraints needed by 'append'.
type Append (axis :: Nat) (dims0 :: [Nat]) (dims1 :: [Nat]) e =
  ( SetSlice (AppendFstTensorStartingIndex dims0) dims0 (DimsAfterAppend axis dims0 dims1) e
  , SetSlice (AppendSndTensorStartingIndex axis dims1) dims1 (DimsAfterAppend axis dims0 dims1) e
  , dims0 ~ NormalizeDims dims0
  , dims1 ~ NormalizeDims dims1
  , Fill (DimsAfterAppend axis dims0 dims1) e
  )
-- | Where 'append' writes the first tensor: 0 in every dimension.
type family AppendFstTensorStartingIndex (dims :: [Nat]) :: [Nat] where
  AppendFstTensorStartingIndex '[] = '[]
  -- Recurse into this family itself. The original recursed through
  -- SnocTensorStartingIndex, which gives the same result but couples two
  -- unrelated helpers.
  AppendFstTensorStartingIndex (_ ': ds) = 0 ': AppendFstTensorStartingIndex ds

-- | Where 'append' writes the second tensor: the first tensor's size on
-- @axis@, 0 elsewhere.
--
-- The size is read from the @dims@ argument; 'append' passes the second
-- tensor's shape, so this presumably relies on the two sizes along @axis@
-- being equal. NOTE(review): confirm with the author.
type AppendSndTensorStartingIndex (axis :: Nat) (dims :: [Nat]) = AppendSndTensorStartingIndex' axis dims 0

-- | Worker for 'AppendSndTensorStartingIndex'; @i@ is the current
-- dimension's position.
type family AppendSndTensorStartingIndex' (axis :: Nat) (dims :: [Nat]) (i :: Nat) :: [Nat] where
  AppendSndTensorStartingIndex' _ '[] _ = '[]
  AppendSndTensorStartingIndex' a (d ': ds) i =
    If (a == i) d 0 ': AppendSndTensorStartingIndex' a ds (i + 1)

-- | Shape after 'append': the two sizes on @axis@ are summed, and every
-- other dimension must be equal in both shapes.
type DimsAfterAppend (axis :: Nat) (dims0 :: [Nat]) (dims1 :: [Nat]) = DimsAfterAppend' axis dims0 dims1 0

-- | Worker for 'DimsAfterAppend'; @i@ is the current dimension's position.
-- After the axis is consumed, @i@ becomes @a + 1@, so at the end of the
-- lists @i <= a@ holds exactly when the axis was never reached.
type family DimsAfterAppend' (axis :: Nat) (dims0 :: [Nat]) (dims1 :: [Nat]) (i :: Nat) :: [Nat] where
  DimsAfterAppend' _ '[] (d1 ': d1s) _ = TypeError ('Text "DimsAfterAppend: Tensors must have the same number of dimensions.")
  DimsAfterAppend' _ (d0 ': d0s) '[] _ = TypeError ('Text "DimsAfterAppend: Tensors must have the same number of dimensions.")
  -- Previously only @i == a@ was rejected. An axis beyond the rank with
  -- equal shapes silently produced the unchanged shape.
  DimsAfterAppend' a '[] '[] i = DimsAfterAppendEnd (i <=? a)
  DimsAfterAppend' a (d0 ': d0s) (d1 ': d1s) a = d0 + d1 ': DimsAfterAppend' a d0s d1s (a + 1)
  DimsAfterAppend' a (d ': d0s) (d ': d1s) i = d ': DimsAfterAppend' a d0s d1s (i + 1)
  DimsAfterAppend' a (d0 ': d0s) (d1 ': d1s) i = TypeError ('Text "DimsAfterAppend: Tensors have incompatible dimensions.")

-- | End of 'DimsAfterAppend'': a type error if the axis was never reached,
-- otherwise the empty tail.
type family DimsAfterAppendEnd (axisNotReached :: Bool) :: [Nat] where
  DimsAfterAppendEnd 'True = TypeError ('Text "DimsAfterAppend: axis must be in range [0..(number of dimensions in the tensor)].")
  DimsAfterAppendEnd 'False = '[]
-- | Traverse every element in 'toList' order, possibly changing the element
-- type while keeping the shape.
instance (IsTensor dims a, IsTensor dims b) => Each (Tensor dims a) (Tensor dims b) a b where
  each f = fmap unsafeFromList . traversed f . toList
-- | The element type of a tensor, as used by the mono-traversable classes.
type instance Element (Tensor shape a) = a
-- | Map over every element, keeping the shape.
instance (MonoFunctorCtx dims e) => MonoFunctor (Tensor dims e) where
  omap f t = unsafeFromList (U.map @(ElemsNumber dims) f (toList t))

-- | Constraints needed by the 'MonoFunctor' instance.
type MonoFunctorCtx (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.Map (ElemsNumber dims)
  )
-- | Folds over the elements in 'toList' order, using the unrolled folds of
-- "Data.List.Unrolled".
instance (MonoFoldableCtx dims e) => MonoFoldable (Tensor dims e) where
  ofoldr f z t = U.foldr @(ElemsNumber dims) f z (toList t)
  ofoldMap f t = U.foldMap @(ElemsNumber dims) f (toList t)
  ofoldl' f z t = U.foldl @(ElemsNumber dims) f z (toList t)
  ofoldr1Ex f t = U.foldr1 @(ElemsNumber dims) f (toList t)
  ofoldl1Ex' f t = U.foldl1 @(ElemsNumber dims) f (toList t)

-- | Constraints needed by the 'MonoFoldable' instance.
type MonoFoldableCtx (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.Foldr (ElemsNumber dims)
  , U.Foldl (ElemsNumber dims)
  , U.Foldr1 (ElemsNumber dims)
  , U.Foldl1 (ElemsNumber dims)
  )
-- | Effectful traversal of the elements in 'toList' order.
instance (MonoTraversableCtx dims e) => MonoTraversable (Tensor dims e) where
  otraverse f = fmap unsafeFromList . traverse f . toList

-- | Constraints needed by the 'MonoTraversable' instance, including those of
-- its 'MonoFunctor' and 'MonoFoldable' superclasses.
type MonoTraversableCtx (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.Map (ElemsNumber dims)
  , U.Foldr (ElemsNumber dims)
  , U.Foldl (ElemsNumber dims)
  , U.Foldr1 (ElemsNumber dims)
  , U.Foldl1 (ElemsNumber dims)
  )
-- | Element-wise zipping of two tensors of the same shape.
instance (MonoZipCtx dims e) => MonoZip (Tensor dims e) where
  ozipWith f t1 t2 =
    unsafeFromList (U.zipWith @(ElemsNumber dims) f (toList t1) (toList t2))
  ozip t1 t2 = U.zip @(ElemsNumber dims) (toList t1) (toList t2)
  ounzip pairs =
    let (firsts, seconds) = U.unzip @(ElemsNumber dims) pairs
    in (unsafeFromList firsts, unsafeFromList seconds)

-- | Constraints needed by the 'MonoZip' instance.
type MonoZipCtx (dims :: [Nat]) e =
  ( IsTensor dims e
  , U.Map (ElemsNumber dims)
  , U.ZipWith (ElemsNumber dims)
  , U.Zip (ElemsNumber dims)
  , U.Unzip (ElemsNumber dims)
  )
-- | A tensor is stored as its elements laid out one after another in
-- 'toList' order. Consecutive elements are 'offsetDiff' bytes apart
-- ('sizeOf' plus padding up to the element's 'alignment').
--
-- The element ranges below were missing their subtraction
-- (@[0 .. count 1]@ does not typecheck); they run over @[0 .. count - 1]@.
instance (IsTensor dims e, Storable e, KnownNat (ElemsNumber dims)) => Storable (Tensor dims e) where
  alignment _ = alignment (undefined :: e)

  sizeOf _ = elemsNumber @dims * offsetDiff (undefined :: e) (undefined :: e)

  peek p = unsafeFromList <$> mapM (\x -> peekByteOff p (x * size)) [0 .. count - 1]
    where
      size = offsetDiff (undefined :: e) (undefined :: e)
      count = elemsNumber @dims

  poke p m = mapM_ (\(i, x) -> pokeByteOff p (size * i) x) $ zip [0 .. count - 1] $ toList m
    where
      size = offsetDiff (undefined :: e) (undefined :: e)
      count = elemsNumber @dims
-- | Run an action with a pointer to the tensor's elements in the 'Storable'
-- layout.
--
-- The tensor is marshalled into temporary memory by 'with'. The pointer is
-- valid only while the action runs, and writes through it do not affect
-- @t@.
unsafeWithTensorPtr :: (IsTensor dims e, Storable e, KnownNat (ElemsNumber dims)) => Tensor dims e -> (Ptr e -> IO a) -> IO a
unsafeWithTensorPtr t f = with t $ \tensorPtr -> f (castPtr tensorPtr)
-- | Bytes of padding needed after a value of @a@'s type so that a value of
-- @b@'s type placed right after it starts at a multiple of @b@'s
-- 'alignment' (assuming the @a@ value itself starts aligned).
--
-- Only the arguments' 'Storable' instances matter; callers pass
-- 'undefined'. The subtraction was missing in the original text
-- (@alignB sizeOf a@ does not typecheck). Because 'mod' is non-negative for
-- a positive divisor, the result lies in @[0, alignB)@.
padding :: (Storable a, Storable b) => a -> b -> Int
padding a b = (alignB - sizeOf a) `mod` alignB
  where
    alignB = alignment b
-- | Distance in bytes from the start of a value of @a@'s type to the start
-- of a following, properly aligned value of @b@'s type: @a@'s size plus
-- the 'padding' between them.
offsetDiff :: (Storable a, Storable b) => a -> b -> Int
offsetDiff a b = padding a b + sizeOf a