{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Tensor where

import Control.Arrow
import Control.Category
import qualified Numeric.Half as N
import Data.Complex
import Data.Finite
import Data.Kind
  ( Constraint,
    Type,
  )
import Data.Maybe
import Data.Proxy
import Data.Reflection
import Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as V
import Foreign.ForeignPtr
import Foreign.Storable
import GHC.Exts
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D hiding (select)
import Torch.HList
import Torch.Internal.Cast
import Torch.Internal.Class
  ( Castable (..),
    CppTuple2 (..),
    CppTuple3 (..),
    CppTuple4 (..),
  )
import qualified Torch.Internal.Type as ATen
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.Typed.Auxiliary
import Prelude hiding (id, (.))

-- | Reflects a type-level shape (a list of dimension sizes) to a runtime
-- list of 'Int's, in the same order as the type-level list.
class KnownShape (shape :: [Nat]) where
  shapeVal :: [Int]

-- | The empty shape reflects to the empty list.
instance KnownShape '[] where
  shapeVal = []

-- | A non-empty shape reflects its head dimension followed by the
-- reflection of its tail.
instance (KnownNat h, KnownShape t) => KnownShape (h ': t) where
  shapeVal = natValI @h : shapeVal @t

-- | Converts a 'Finite' @n@ index to an 'Int'.
--
-- The conversion goes through 'getFinite' (an 'Integer') and 'fromIntegral',
-- so a value larger than @maxBound :: Int@ would wrap silently rather than
-- raise an error.
getFiniteI :: Finite n -> Int
getFiniteI = fromIntegral . getFinite

-- | Reflects a promoted 'D.DType' to its runtime 'D.DType' value.
class KnownDType (dtype :: D.DType) where
  dtypeVal :: D.DType

instance KnownDType 'D.Bool where
  dtypeVal = D.Bool

instance KnownDType 'D.UInt8 where
  dtypeVal = D.UInt8

instance KnownDType 'D.Int8 where
  dtypeVal = D.Int8

instance KnownDType 'D.Int16 where
  dtypeVal = D.Int16

instance KnownDType 'D.Int32 where
  dtypeVal = D.Int32

instance KnownDType 'D.Int64 where
  dtypeVal = D.Int64

instance KnownDType 'D.Half where
  dtypeVal = D.Half

instance KnownDType 'D.Float where
  dtypeVal = D.Float

instance KnownDType 'D.Double where
  dtypeVal = D.Double

instance KnownDType 'D.ComplexHalf where
  dtypeVal = D.ComplexHalf

instance KnownDType 'D.ComplexFloat where
  dtypeVal = D.ComplexFloat

instance KnownDType 'D.ComplexDouble where
  dtypeVal = D.ComplexDouble

-- | Maps either a Haskell element type or an already-promoted 'D.DType' to
-- the corresponding promoted 'D.DType'.
--
-- The argument is poly-kinded (@dtype' :: dtype@), so both @Float@ (kind
-- 'Type') and @D.Float@ (kind 'D.DType') are accepted. Haskell types handled
-- here: 'Bool', 'Int' (mapped to @D.Int64@), @N.Half@, 'Float', 'Double' and
-- 'Complex' of @N.Half@ / 'Float' / 'Double'. UInt8, Int8, Int16 and Int32 are
-- only reachable through their promoted 'D.DType' spelling.
--
-- The last equation is a catch-all. In a closed type family it only fires
-- when the argument cannot match any earlier equation, so it must stay last.
-- It turns unsupported types into a custom 'TypeError' instead of a stuck type.
type family ComputeDType (dtype' :: dtype) :: D.DType where
  ComputeDType Bool = D.Bool
  ComputeDType D.Bool = D.Bool
  ComputeDType D.UInt8 = D.UInt8
  ComputeDType D.Int8 = D.Int8
  ComputeDType D.Int16 = D.Int16
  ComputeDType D.Int32 = D.Int32
  ComputeDType Int = D.Int64
  ComputeDType D.Int64 = D.Int64
  ComputeDType N.Half = D.Half
  ComputeDType D.Half = D.Half
  ComputeDType Float = D.Float
  ComputeDType D.Float = D.Float
  ComputeDType Double = D.Double
  ComputeDType D.Double = D.Double
  ComputeDType (Complex N.Half) = D.ComplexHalf
  ComputeDType D.ComplexHalf = D.ComplexHalf
  ComputeDType (Complex Float) = D.ComplexFloat
  ComputeDType D.ComplexFloat = D.ComplexFloat
  ComputeDType (Complex Double) = D.ComplexDouble
  ComputeDType D.ComplexDouble = D.ComplexDouble
  -- Fallback: must remain the final equation (see note above).
  ComputeDType dtype' = TypeError (Text "Unsupported tensor type " :<>: ShowType dtype')

-- | Reflects a type-level device (a device type paired with a device index)
-- to its runtime 'D.Device'.
class KnownDevice (device :: (D.DeviceType, Nat)) where
  deviceVal :: D.Device

-- | CPU device with index @n@.
instance (KnownNat n) => KnownDevice '( 'D.CPU, n) where
  deviceVal = D.Device D.CPU (natValInt16 @n)

-- | CUDA device with index @n@.
instance (KnownNat n) => KnownDevice '( 'D.CUDA, n) where
  deviceVal = D.Device D.CUDA (natValInt16 @n)

-- | MPS device with index @n@.
instance (KnownNat n) => KnownDevice '( 'D.MPS, n) where
  deviceVal = D.Device D.MPS (natValInt16 @n)

-- | A dimension described by a type constructor of kind @Type -> Type@
-- (for example @Vector n@).
type Size = Type -> Type

-- | A shape described as a list of 'Size' constructors.
type Shape = [Type -> Type]

-- | Computes the number of elements a 'Size' stands for.
--
-- Generic metadata wrappers ('S1', 'D1', 'C1') are looked through. Products
-- add their sizes, and sums take the larger of the two. A field of type
-- @Vector n _@ contributes @n@, any other field contributes 1, and 'U1'
-- contributes 1. A bare @Vector n@ is @n@. Anything else is unfolded via its
-- generic representation @Rep (a ())@.
--
-- Equation order matters in this closed family: the @K1 R (Vector n _)@ case
-- must precede the general @K1 _ _@ case, and the @Rep@ unfolding must stay
-- last.
type family ToNat (shape :: Size) :: Nat where
  ToNat (S1 ('MetaSel _ _ _ _) f) = ToNat f
  ToNat (D1 _ f) = ToNat f
  ToNat (C1 _ f) = ToNat f
  ToNat (l :*: r) = ToNat l + ToNat r
  ToNat (l :+: r) = If (ToNat l <=? ToNat r) (ToNat r) (ToNat l)
  ToNat (K1 R (Vector n _)) = n
  ToNat (K1 _ _) = 1
  ToNat U1 = 1
  ToNat (Vector n) = n
  ToNat a = ToNat (Rep (a ()))

-- | Applies 'ToNat' to every 'Size' in a 'Shape'.
type family ToNats (shape :: Shape) :: [Nat] where
  ToNats '[] = '[]
  ToNats (x ': xs) = ToNat x ': ToNats xs

-- | Turns a dimension size into the 'Size' @Vector n@.
type family FromNat (shape :: Nat) :: Size where
  FromNat n = Vector n

-- | Applies 'FromNat' to every dimension of a type-level shape.
type family FromNats (shape :: [Nat]) :: Shape where
  FromNats '[] = '[]
  FromNats (x ': xs) = FromNat x ': FromNats xs

-- | Types that can be viewed as a typed 'Tensor'. The device, dtype and
-- shape of that view are fixed by @t@ through the associated type families.
class Unnamed t where
  -- | Promoted device (device type and index) of @t@.
  type UTDevice t :: (D.DeviceType, Nat)
  -- | Promoted element dtype of @t@.
  type UTDType t :: D.DType
  -- | Type-level shape of @t@.
  type UTShape t :: [Nat]

  -- | View @t@ as a typed tensor whose indices agree with @t@'s associated types.
  toUnnamed :: forall device dtype shape. IsUnnamed t device dtype shape => t -> Tensor device dtype shape

  -- | Rebuild @t@ from a typed tensor whose indices agree with @t@'s associated types.
  fromUnnamed :: forall device dtype shape. IsUnnamed t device dtype shape => Tensor device dtype shape -> t

  -- | Drop all type-level information, yielding the untyped tensor.
  toDynamic :: t -> D.Tensor

-- | Bundles an 'Unnamed' constraint with equalities that pin @device@,
-- @dtype@ and @shape@ to the matching associated types of @t@.
type family IsUnnamed t (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: [Nat]) :: Constraint where
  IsUnnamed t device dtype shape =
    ( Unnamed t
    , UTDevice t ~ device
    , UTDType t ~ dtype
    , UTShape t ~ shape
    )

-- | A typed 'Tensor' is trivially 'Unnamed'. Its indices are read directly
-- off its type, the typed conversions are identities, and 'toDynamic'
-- unwraps the untyped tensor stored in 'UnsafeMkTensor'.
instance Unnamed (Tensor device dtype shape) where
  type UTShape (Tensor device dtype shape) = shape
  type UTDevice (Tensor device dtype shape) = device
  type UTDType (Tensor device dtype shape) = dtype
  toUnnamed = id
  fromUnnamed = id
  toDynamic (UnsafeMkTensor t) = t

-- | A tensor whose device, element dtype and shape are tracked at the type
-- level.
--
-- At runtime it only wraps an untyped 'D.Tensor'. 'UnsafeMkTensor' performs
-- no check that the wrapped tensor actually matches the type-level indices.
data Tensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: [Nat]) where
  UnsafeMkTensor :: forall device dtype shape. D.Tensor -> Tensor device dtype shape

-- | Typed tensor on CPU device index 0.
type CPUTensor = Tensor '( 'D.CPU, 0)

-- | Typed tensor on the CUDA device with the given index.
type CUDATensor deviceIndex = Tensor '( 'D.CUDA, deviceIndex)

-- | Typed tensor on the MPS device with the given index.
--
-- The previous definition ignored @deviceIndex@ and always used index 0.
-- It now forwards the index, matching 'CUDATensor'.
type MPSTensor deviceIndex = Tensor '( 'D.MPS, deviceIndex)

-- | A typed tensor whose shape has been hidden existentially.
data UnknownShapeTensor device dtype = forall shape. UnknownShapeTensor (Tensor device dtype shape)

-- | The Haskell element type used when converting tensors of the given dtype
-- to and from nested lists (used by the 'IsList' instance in this module).
--
-- Only Bool, Int64 (as 'Int'), Float and Double are mapped. Every other dtype
-- reaches the final catch-all equation and is rejected with a custom
-- 'TypeError', so that equation must stay last.
type family ComputeHaskellType (dtype :: D.DType) :: Type where
  ComputeHaskellType D.Bool = Bool
  ComputeHaskellType D.Int64 = Int
  ComputeHaskellType D.Float = Float
  ComputeHaskellType D.Double = Double
  ComputeHaskellType dtype = TypeError (Text "Unsupported tensor type " :<>: ShowType dtype)

-- | The list element type for a tensor with element type @ty@ and the given
-- shape.
--
-- A rank-1 shape gives @ty@, and each further dimension adds one list layer.
-- The outermost dimension corresponds to the enclosing list, which is not
-- part of the item type. The empty (scalar) shape is rejected with a
-- 'TypeError'.
type family ComputeItemType (ty :: Type) (shape :: [Nat]) :: Type where
  ComputeItemType _ '[] = TypeError (Text "Scalars are not supported")
  ComputeItemType ty (_ ': '[]) = ty
  ComputeItemType ty (_ ': h ': t) = [ComputeItemType ty (h ': t)]

-- | Conversion between nested Haskell lists and typed tensors.
--
-- 'fromList' returns 'Nothing' when the dimensions reported by
-- 'D._deepDims' are unavailable or differ from the type-level @shape@.
-- Otherwise it builds the tensor with 'D.asTensor' and moves it to @device@.
--
-- 'toList' maps 'Nothing' to the empty list. For a tensor, it first moves
-- the data to CPU device 0 and then reads the values back with 'D.asValue'.
instance
  ( D.TensorLike [ComputeItemType (ComputeHaskellType dtype) shape],
    KnownDevice device,
    KnownShape shape
  ) =>
  IsList (Maybe (Tensor device dtype shape))
  where
  type Item (Maybe (Tensor device dtype shape)) = ComputeItemType (ComputeHaskellType dtype) shape
  fromList xs = do
    shapeXs <- D._deepDims xs
    if shapeVal @shape == shapeXs
      then return $ UnsafeMkTensor . D.toDevice (deviceVal @device) . D.asTensor $ xs
      else Nothing
  toList Nothing = []
  toList (Just t) = D.asValue . D.toDevice (D.Device D.CPU 0) . toDynamic $ t

-- | Element-wise arithmetic, delegated to the 'Num' instance of the untyped
-- 'D.Tensor'. Both operands carry the same type-level indices; the runtime
-- tensors are not re-checked.
--
-- 'fromInteger' narrows the literal to 'Int', builds a tensor from it with
-- @D.asTensor \@Int@, and moves that tensor to @device@.
-- NOTE(review): the runtime dtype and shape of that tensor come from the
-- 'Int' scalar, not from the type-level @dtype@/@shape@. Confirm that callers
-- rely on broadcasting/promotion here.
instance KnownDevice device => Num (Tensor device dtype shape) where
  (+) a b = UnsafeMkTensor $ toDynamic a + toDynamic b
  (-) a b = UnsafeMkTensor $ toDynamic a - toDynamic b
  (*) a b = UnsafeMkTensor $ toDynamic a * toDynamic b
  negate t = UnsafeMkTensor $ negate $ toDynamic t
  abs t = UnsafeMkTensor $ abs $ toDynamic t
  signum t = UnsafeMkTensor $ signum $ toDynamic t
  fromInteger i = UnsafeMkTensor . D.toDevice (deviceVal @device) . D.asTensor @Int $ fromInteger @Int i

-- | Element-wise division and reciprocal, delegated to the 'Fractional'
-- instance of the untyped 'D.Tensor'.
--
-- 'fromRational' narrows the rational to 'Float', builds a tensor from it
-- with @D.asTensor \@Float@, and moves that tensor to @device@.
-- NOTE(review): as with 'fromInteger', the runtime dtype and shape come from
-- the 'Float' scalar rather than from the type-level @dtype@/@shape@.
instance KnownDevice device => Fractional (Tensor device dtype shape) where
  a / b = UnsafeMkTensor $ toDynamic a / toDynamic b
  recip t = UnsafeMkTensor $ recip $ toDynamic t
  fromRational i = UnsafeMkTensor . D.toDevice (deviceVal @device) . D.asTensor @Float $ fromRational @Float i

-- | Shows the wrapped untyped tensor. The type-level indices are not printed.
instance Show (Tensor device dtype shape) where
  show (UnsafeMkTensor dynamic) = show dynamic

-- | Reflects a type-level shape, dtype and device to their runtime values.
class TensorOptions (shape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  optionsRuntimeShape :: [Int]
  optionsRuntimeDType :: D.DType
  optionsRuntimeDevice :: D.Device

-- | Base case: the empty shape. The dtype and device come from 'KnownDType'
-- and 'KnownDevice'.
instance (KnownDType dtype, KnownDevice device) => TensorOptions '[] dtype device where
  optionsRuntimeShape = []
  optionsRuntimeDType = dtypeVal @dtype
  optionsRuntimeDevice = deviceVal @device

-- | Inductive case: prepend the head dimension and defer dtype/device to the
-- instance for the tail.
instance (KnownNat h, TensorOptions t dtype device) => TensorOptions (h ': t) dtype device where
  optionsRuntimeShape = natValI @h : optionsRuntimeShape @t @dtype @device
  optionsRuntimeDType = optionsRuntimeDType @t @dtype @device
  optionsRuntimeDevice = optionsRuntimeDevice @t @dtype @device

--------------------------------------------------------------------------------
-- Untyped -> Typed typecasts
--------------------------------------------------------------------------------

-- | @All c xs@ holds when the constraint @c@ holds for every element of @xs@.
-- It reduces to the empty constraint @()@ on the empty list.
type family All (c :: k -> Constraint) (xs :: [k]) :: Constraint where
  All _ '[] = ()
  All c (x ': rest) = (c x, All c rest)

-- | A type-level shape packaged with its 'KnownShape' evidence, for shapes
-- that are only known at runtime.
data SomeShape where
  SomeShape :: forall (shape :: [Nat]). KnownShape shape => Proxy shape -> SomeShape

-- | Promotes a runtime shape to a 'SomeShape', one dimension at a time.
--
-- Calls 'error' if any dimension is negative, because 'someNatVal' rejects
-- negative integers.
someShape :: [Int] -> SomeShape
someShape [] = SomeShape $ Proxy @'[]
someShape (h : t) = case someNatVal (fromIntegral h) of
  Nothing -> error "Negative dimension in someShape!"
  Just (SomeNat (Proxy :: Proxy ht)) -> case someShape t of
    SomeShape (Proxy :: Proxy tt) -> SomeShape $ Proxy @(ht ': tt)

-- | A promoted 'D.DType' packaged with its 'KnownDType' evidence, for dtypes
-- that are only known at runtime.
data SomeDType where
  SomeDType :: forall (dtype :: D.DType). KnownDType dtype => Proxy dtype -> SomeDType

-- | Promotes a runtime 'D.DType' to a 'SomeDType'.
--
-- There is one equation for every dtype that has a 'KnownDType' instance,
-- including the three complex dtypes. Those were previously missing, which
-- made @someDType@ (and therefore 'withTensor') fail with a pattern-match
-- error on complex tensors.
someDType :: D.DType -> SomeDType
someDType D.Bool = SomeDType $ Proxy @'D.Bool
someDType D.UInt8 = SomeDType $ Proxy @'D.UInt8
someDType D.Int8 = SomeDType $ Proxy @'D.Int8
someDType D.Int16 = SomeDType $ Proxy @'D.Int16
someDType D.Int32 = SomeDType $ Proxy @'D.Int32
someDType D.Int64 = SomeDType $ Proxy @'D.Int64
someDType D.Half = SomeDType $ Proxy @'D.Half
someDType D.Float = SomeDType $ Proxy @'D.Float
someDType D.Double = SomeDType $ Proxy @'D.Double
someDType D.ComplexHalf = SomeDType $ Proxy @'D.ComplexHalf
someDType D.ComplexFloat = SomeDType $ Proxy @'D.ComplexFloat
someDType D.ComplexDouble = SomeDType $ Proxy @'D.ComplexDouble

-- | A promoted device (device type and index) packaged with its
-- 'KnownDevice' evidence, for devices that are only known at runtime.
data SomeDevice where
  SomeDevice :: forall (device :: (D.DeviceType, Nat)). KnownDevice device => Proxy device -> SomeDevice

-- | Promotes a runtime 'D.Device' to a 'SomeDevice'.
--
-- Calls 'error' if the device index is negative, because 'someNatVal'
-- rejects negative integers.
someDevice :: D.Device -> SomeDevice
someDevice D.Device {..} = case someNatVal (fromIntegral deviceIndex) of
  Nothing -> error "Negative device index in someDevice!"
  Just (SomeNat (Proxy :: Proxy n)) -> case deviceType of
    D.CPU -> SomeDevice $ Proxy @'( 'D.CPU, n)
    D.CUDA -> SomeDevice $ Proxy @'( 'D.CUDA, n)
    D.MPS -> SomeDevice $ Proxy @'( 'D.MPS, n)

-- | Runs a continuation on a typed view of an untyped tensor.
--
-- The shape, dtype and device are recovered from the runtime values
-- ('D.shape', 'D.dtype', 'D.device') via 'someShape', 'someDType' and
-- 'someDevice'. Any 'error' those raise (negative dimension or device index)
-- propagates to the caller.
withTensor ::
  D.Tensor ->
  ( forall shape dtype device.
    ( KnownDevice device,
      KnownDType dtype,
      KnownShape shape
    ) =>
    Tensor device dtype shape ->
    r
  ) ->
  r
withTensor untypedTensor f = case someShape (D.shape untypedTensor) of
  SomeShape (Proxy :: Proxy shape) -> case someDType (D.dtype untypedTensor) of
    SomeDType (Proxy :: Proxy dtype) -> case someDevice (D.device untypedTensor) of
      SomeDevice (Proxy :: Proxy device) -> f $ UnsafeMkTensor @device @dtype @shape untypedTensor

-- | Runs a continuation on a typed view of an untyped tensor whose device and
-- dtype are supplied by the caller at the type level.
--
-- Only the shape is recovered at runtime (via 'someShape'). The caller's
-- @device@ and @dtype@ are trusted and not compared with the untyped tensor.
withTensorShape ::
  forall device dtype r.
  ( KnownDevice device,
    KnownDType dtype
  ) =>
  D.Tensor ->
  ( forall shape.
    KnownShape shape =>
    Tensor device dtype shape ->
    r
  ) ->
  r
withTensorShape untypedTensor f = case someShape (D.shape untypedTensor) of
  -- ToDo: check device/dtype of untyped tensor.
  SomeShape (Proxy :: Proxy shape) -> f $ UnsafeMkTensor @device @dtype @shape untypedTensor

--------------------------------------------------------------------------------
-- Broadcast type-level function
--------------------------------------------------------------------------------

-- | Broadcasts two shapes that are given in reversed order (last dimension
-- first), as supplied by 'Broadcast' below. The result is the reversed
-- broadcast shape, or 'Nothing' if the shapes are incompatible.
--
-- Equations are tried top to bottom, and the order matters:
--   * when either shape runs out, the remaining dimensions of the other are kept;
--   * equal dimensions are kept;
--   * a 1 on either side takes the other side's size;
--   * anything else falls through to the final equation and yields 'Nothing'.
-- 'AppendToMaybe' (defined elsewhere) presumably conses @h@ onto a 'Just'
-- result and propagates 'Nothing'. TODO confirm.
type family ComputeBroadcast (reversedShape :: [Nat]) (reversedShape' :: [Nat]) :: Maybe [Nat] where
  ComputeBroadcast '[] reversedShape = Just reversedShape
  ComputeBroadcast reversedShape '[] = Just reversedShape
  ComputeBroadcast (h ': t) (h ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast (h ': t) (1 ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast (1 ': t) (h ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast _ _ = Nothing

-- | Turns a 'ComputeBroadcast' result into a shape. A 'Nothing' result
-- becomes a custom 'TypeError' naming both input shapes. A 'Just' result is
-- reversed back into the usual dimension order.
type family CheckBroadcast (shape :: [Nat]) (shape' :: [Nat]) (result :: Maybe [Nat]) :: [Nat] where
  CheckBroadcast shape shape' Nothing =
    TypeError
      ( Text "The shapes "
          :<>: ShowType shape
          :<>: Text " and "
          :<>: ShowType shape'
          :<>: Text " cannot be broadcast"
      )
  CheckBroadcast _ _ (Just result) = (Reverse result)

-- | The broadcast shape of @shape@ and @shape'@. Both are reversed so that
-- trailing dimensions are aligned, then broadcast and checked.
type Broadcast shape shape' =
  CheckBroadcast
    shape
    shape'
    ( ComputeBroadcast
        (Reverse shape)
        (Reverse shape')
    )

-- | Constraint that @dtype@ is acceptable for the basic arithmetic operators
-- on @device@.
--
-- * On @'( 'D.CPU, 0)@ the dtype must satisfy both @DTypeIsNotBool@ and
--   @DTypeIsNotHalf@.
-- * On any CUDA device, and on @'( 'D.MPS, 0)@, every dtype is accepted.
-- * Every other device (for instance a CPU or MPS device with a non-zero
--   index) yields @UnsupportedDTypeForDevice@.
type family BasicArithmeticDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  BasicArithmeticDTypeIsValid '( 'D.CPU, 0) dt =
    ( DTypeIsNotBool '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  BasicArithmeticDTypeIsValid '( 'D.CUDA, _) _ = ()
  BasicArithmeticDTypeIsValid '( 'D.MPS, 0) _ = ()
  BasicArithmeticDTypeIsValid '(deviceKind, _) dt = UnsupportedDTypeForDevice deviceKind dt

-- | Element-wise arithmetic on two typed tensors living on the same device.
--
-- The operand shapes are broadcast against each other ('Broadcast'), and the
-- result dtype is the 'DTypePromotion' of the two operand dtypes. Both
-- operand dtypes and the promoted dtype must satisfy
-- 'BasicArithmeticDTypeIsValid' for the device. Each function forwards to
-- the matching untyped operation ('D.add', 'D.sub', 'D.mul', 'D.div').
add,
  sub,
  mul,
  div ::
    forall shape'' shape shape' dtype dtype' dtype'' device.
    ( dtype'' ~ DTypePromotion dtype dtype',
      shape'' ~ Broadcast shape shape',
      BasicArithmeticDTypeIsValid device dtype,
      BasicArithmeticDTypeIsValid device dtype',
      BasicArithmeticDTypeIsValid device dtype''
    ) =>
    Tensor device dtype shape ->
    Tensor device dtype' shape' ->
    Tensor device dtype'' shape''
add x y = UnsafeMkTensor (D.add (toDynamic x) (toDynamic y))
sub x y = UnsafeMkTensor (D.sub (toDynamic x) (toDynamic y))
mul x y = UnsafeMkTensor (D.mul (toDynamic x) (toDynamic y))
div x y = UnsafeMkTensor (D.div (toDynamic x) (toDynamic y))

-- | Constraint that @dtype@ is acceptable for element-wise comparisons on
-- @device@.
--
-- * On @'( 'D.CPU, 0)@ the dtype must satisfy both @DTypeIsNotBool@ and
--   @DTypeIsNotHalf@.
-- * On any CUDA device, and on @'( 'D.MPS, 0)@, every dtype is accepted.
-- * Every other device yields @UnsupportedDTypeForDevice@.
type family ComparisonDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  ComparisonDTypeIsValid '( 'D.CPU, 0) dt =
    ( DTypeIsNotBool '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  ComparisonDTypeIsValid '( 'D.CUDA, _) _ = ()
  ComparisonDTypeIsValid '( 'D.MPS, 0) _ = ()
  ComparisonDTypeIsValid '(deviceKind, _) dt = UnsupportedDTypeForDevice deviceKind dt

-- | Element-wise comparisons between two typed tensors on the same device.
--
-- The operand shapes are broadcast ('Broadcast') and the result is always a
-- 'D.Bool' tensor of the broadcast shape. Both operand dtypes must satisfy
-- 'ComparisonDTypeIsValid'. The named functions forward to the matching
-- untyped operation ('D.gt', 'D.lt', 'D.ge', 'D.le', 'D.eq', 'D.ne'); the
-- operators are infix synonyms for them:
--
-- @(>.) = gt, (<.) = lt, (>=.) = ge, (<=.) = le, (==.) = eq, (/=.) = ne@
gt,
  lt,
  ge,
  le,
  eq,
  ne,
  (>.),
  (<.),
  (>=.),
  (<=.),
  (==.),
  (/=.) ::
    forall shape'' shape shape' dtype dtype' device.
    ( shape'' ~ Broadcast shape shape',
      ComparisonDTypeIsValid device dtype,
      ComparisonDTypeIsValid device dtype'
    ) =>
    Tensor device dtype shape ->
    Tensor device dtype' shape' ->
    Tensor device 'D.Bool shape''
gt x y = UnsafeMkTensor (D.gt (toDynamic x) (toDynamic y))
lt x y = UnsafeMkTensor (D.lt (toDynamic x) (toDynamic y))
ge x y = UnsafeMkTensor (D.ge (toDynamic x) (toDynamic y))
le x y = UnsafeMkTensor (D.le (toDynamic x) (toDynamic y))
eq x y = UnsafeMkTensor (D.eq (toDynamic x) (toDynamic y))
ne x y = UnsafeMkTensor (D.ne (toDynamic x) (toDynamic y))
x >. y = gt x y
x <. y = lt x y
x >=. y = ge x y
x <=. y = le x y
x ==. y = eq x y
x /=. y = ne x y

-- | Result shape of @torch.matmul@, computed on reversed shapes (innermost
-- dimension first). The result is also reversed.
--
-- In normal (non-reversed) order the supported cases are:
--
-- * @[k] x [k]@ gives the scalar shape @[]@;
-- * @[k] x [..., k, m]@ gives @[..., m]@;
-- * @[..., n, k] x [k]@ gives @[..., n]@;
-- * @[..., n, k] x [..., k, m]@ gives @[..., n, m]@, with the leading batch
--   dimensions broadcast via 'ComputeBroadcast'.
--
-- Any other combination (e.g. mismatched contraction dimensions or a 0-d
-- operand) reduces to 'Nothing'. Without this final equation the family
-- would get stuck on such shapes, and 'CheckMatMul' could never report its
-- readable type error for them.
type family ComputeMatMul (reversedShape :: [Nat]) (reversedShape' :: [Nat]) :: Maybe [Nat] where
  ComputeMatMul (k ': '[]) (k ': '[]) = 'Just '[]
  ComputeMatMul (k ': '[]) (m ': k ': reversedBroadcastShape') = AppendToMaybe m (ComputeBroadcast '[] reversedBroadcastShape')
  ComputeMatMul (k ': n ': reversedBroadcastShape) (k ': '[]) = AppendToMaybe n (ComputeBroadcast '[] reversedBroadcastShape)
  ComputeMatMul (k ': n ': reversedBroadcastShape) (m ': k ': reversedBroadcastShape') = AppendToMaybe m (AppendToMaybe n (ComputeBroadcast reversedBroadcastShape reversedBroadcastShape'))
  ComputeMatMul _ _ = 'Nothing

-- | Turn the outcome of 'ComputeMatMul' into a shape. A 'Nothing' outcome
-- becomes a custom type error naming both operand shapes; a 'Just' outcome is
-- reversed back into normal (outermost-first) order.
type family CheckMatMul (shape :: [Nat]) (shape' :: [Nat]) (result :: Maybe [Nat]) :: [Nat] where
  CheckMatMul lhs rhs 'Nothing =
    TypeError
      ( 'Text "The shapes "
          ':<>: 'ShowType lhs
          ':<>: 'Text " and "
          ':<>: 'ShowType rhs
          ':<>: 'Text " are not compatible with matrix multiplication"
      )
  CheckMatMul _ _ ('Just reversedResult) = Reverse reversedResult

-- | Shape of @matmul@ applied to tensors of shapes @shape@ and @shape'@, or a
-- type error (via 'CheckMatMul') when they are incompatible.
type MatMul shape shape' =
  CheckMatMul
    shape
    shape'
    (ComputeMatMul (Reverse shape) (Reverse shape'))

-- | Constraint that @dtype@ is acceptable for 'matmul' on @device@.
--
-- * On @'( 'D.CPU, 0)@ the dtype must satisfy both @DTypeIsNotBool@ and
--   @DTypeIsNotHalf@.
-- * On any CUDA device, and on @'( 'D.MPS, 0)@, the dtype must satisfy
--   @DTypeIsFloatingPoint@ for that device.
-- * Every other device yields @UnsupportedDTypeForDevice@.
type family MatMulDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MatMulDTypeIsValid '( 'D.CPU, 0) dt =
    ( DTypeIsNotBool '( 'D.CPU, 0) dt,
      DTypeIsNotHalf '( 'D.CPU, 0) dt
    )
  MatMulDTypeIsValid '( 'D.CUDA, index) dt = DTypeIsFloatingPoint '( 'D.CUDA, index) dt
  MatMulDTypeIsValid '( 'D.MPS, 0) dt = DTypeIsFloatingPoint '( 'D.MPS, 0) dt
  MatMulDTypeIsValid '(deviceKind, _) dt = UnsupportedDTypeForDevice deviceKind dt

-- | Matrix multiplication of two typed tensors with the same dtype and device.
--
-- The result shape is computed at the type level by 'MatMul'. It covers
-- vector-vector, vector-matrix, matrix-vector and batched matrix-matrix
-- products with broadcast batch dimensions. The computation itself is
-- delegated to the untyped 'D.matmul'.
-- See https://pytorch.org/docs/stable/torch.html#torch.matmul.
matmul ::
  forall shape'' shape shape' dtype device.
  ( shape'' ~ MatMul shape shape',
    MatMulDTypeIsValid device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape' ->
  Tensor device dtype shape''
matmul lhs rhs = UnsafeMkTensor (D.matmul (toDynamic lhs) (toDynamic rhs))

-- | Take the slice at the type-level index @idx@ along the type-level
-- dimension @dim@, dropping that dimension from the shape ('Remove').
-- @InRange shape dim idx@ constrains @idx@ to be valid for that dimension.
select ::
  forall dim idx shape' shape dtype device.
  ( KnownNat dim,
    KnownNat idx,
    InRange shape dim idx,
    shape' ~ Remove shape dim
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
select t = UnsafeMkTensor (D.select axis position (toDynamic t))
  where
    axis = natValI @dim
    position = natValI @idx

-- | Like 'select', but the index is a runtime value of type @Finite n@, where
-- @n@ is the size of dimension @dim@ ('Index'). The selected dimension is
-- dropped from the shape ('Remove').
selectIdx ::
  forall dim n shape' shape dtype device.
  ( KnownNat dim,
    n ~ Index shape dim,
    shape' ~ Remove shape dim
  ) =>
  Tensor device dtype shape ->
  Finite n ->
  Tensor device dtype shape'
selectIdx t idx =
  let axis = natValI @dim
      position = getFiniteI idx
   in UnsafeMkTensor (D.select axis position (toDynamic t))

-- | Number of elements of a tensor with the given shape: the product of all
-- dimension sizes. The empty (scalar) shape has exactly one element.
type family Numel (shape :: [Nat]) :: Nat where
  Numel '[] = 1
  Numel (d ': ds) = d * Numel ds

-- | Reinterpret a tensor with a new shape @shape'@. The total number of
-- elements must be unchanged, which is checked at the type level via
-- 'Numel'. The target shape is passed to the untyped 'D.reshape'.
--
-- >>> t :: CPUTensor 'D.Int64 '[2,3,4] = fromJust [[[111,112,113,114],[121,122,123,124],[131,132,133,134]],[[211,212,213,214],[221,222,223,224],[231,232,233,234]]]
-- >>> t' = reshape @'[24] t
-- >>> toList . Just $ t'
-- [111,112,113,114,121,122,123,124,131,132,133,134,211,212,213,214,221,222,223,224,231,232,233,234]
-- >>> toList . Just $ reshape @'[2,3,4] t'
-- [[[111,112,113,114],[121,122,123,124],[131,132,133,134]],[[211,212,213,214],[221,222,223,224],[231,232,233,234]]]
reshape ::
  forall shape' shape dtype device.
  ( KnownShape shape',
    Numel shape ~ Numel shape'
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
reshape t = UnsafeMkTensor (D.reshape targetShape (toDynamic t))
  where
    targetShape = shapeVal @shape'

-- | Trivial wrapper used as the key of the
-- @Unnamed t => Castable (Wrap t) D.ATenTensor@ instance. Keying on 'Wrap'
-- avoids declaring a catch-all @Unnamed t => Castable t D.ATenTensor@
-- instance, which would overlap with the other @Castable _ D.ATenTensor@
-- instances.
newtype Wrap a = Wrap
  { unWrap :: a
  }

-- | Marshal any 'Unnamed' typed tensor (wrapped in 'Wrap') to and from a raw
-- 'D.ATenTensor' pointer.
instance {-# OVERLAPS #-} Unnamed t => Castable (Wrap t) D.ATenTensor where
  -- Erase the type indices with 'toDynamic' and hand the underlying ATen
  -- pointer to the continuation.
  cast (Wrap wrapped) k = k atenPtr
    where
      D.Unsafe atenPtr = toDynamic wrapped

  -- Rebuild the typed value with 'fromUnnamed'. Its device, dtype and shape
  -- indices come from the 'Unnamed' instance for @t@; nothing checks them
  -- against the runtime tensor.
  uncast atenPtr k = k (Wrap (fromUnnamed (UnsafeMkTensor (D.Unsafe atenPtr))))

-- | Marshal a 'NamedTensor' to and from a raw 'D.ATenTensor' pointer by going
-- through the underlying typed 'Tensor' ('FromTensor' / 'UnsafeMkTensor').
-- The type indices are not checked against the runtime tensor.
instance Castable (NamedTensor device dtype shape) D.ATenTensor where
  cast named k = case named of
    FromTensor (UnsafeMkTensor (D.Unsafe atenPtr)) -> k atenPtr
  uncast atenPtr k = k (FromTensor (UnsafeMkTensor (D.Unsafe atenPtr)))

-- | Marshal a typed 'Tensor' to and from a raw 'D.ATenTensor' pointer. The
-- conversion only unwraps or rewraps the pointer; the device, dtype and
-- shape indices are not checked against the runtime tensor.
instance Castable (Tensor device dtype shape) D.ATenTensor where
  cast tensor k = case tensor of
    UnsafeMkTensor (D.Unsafe atenPtr) -> k atenPtr
  uncast atenPtr k = k (UnsafeMkTensor (D.Unsafe atenPtr))

-- | Marshal a list of typed tensors to and from an ATen @TensorList@ pointer.
-- Each element goes through the single-tensor 'Castable' instance, and the
-- resulting list of tensor pointers is then cast as a whole.
instance Castable [Tensor device dtype shape] (ForeignPtr ATen.TensorList) where
  cast tensors k =
    traverse (\t -> (cast t pure :: IO (ForeignPtr ATen.Tensor))) tensors
      >>= \ptrs -> cast ptrs k
  uncast listPtr k =
    uncast listPtr $ \ptrs ->
      traverse (\(p :: ForeignPtr ATen.Tensor) -> uncast p pure) ptrs >>= k

-- | Marshal a length-indexed vector of typed tensors to and from an ATen
-- @TensorList@ pointer.
--
-- 'cast' converts every element through the single-tensor 'Castable'
-- instance and casts the resulting pointer list.
--
-- 'uncast' converts the pointer list back into tensors and requires exactly
-- @n@ of them. 'V.fromList' is used rather than 'V.fromListN', which would
-- silently drop surplus tensors and so hide a length mismatch. On a mismatch
-- an 'IOError' describing both lengths is raised instead of a bare
-- pattern-match failure.
instance KnownNat n => Castable (Vector n (Tensor device dtype shape)) (ForeignPtr ATen.TensorList) where
  cast xs f = do
    ptr_list <- V.toList <$> mapM (\x -> (cast x return :: IO (ForeignPtr ATen.Tensor))) xs
    cast ptr_list f
  uncast xs f = uncast xs $ \ptr_list -> do
    tensor_list <- mapM (\(x :: ForeignPtr ATen.Tensor) -> uncast x return) ptr_list
    case V.fromList tensor_list of
      Just vec -> f vec
      Nothing ->
        ioError
          ( userError
              ( "Castable (Vector n Tensor) TensorList: expected "
                  ++ show (natValI @n)
                  ++ " tensors but got "
                  ++ show (length tensor_list)
              )
          )

-- | Tag type selecting the @Apply'@ instance that folds the elements of an
-- 'HList' (each 'Castable' to 'D.ATenTensor') into a list of 'D.ATenTensor'.
data TensorListFold = TensorListFold

-- | One step of the right fold: run the accumulated action to get the tensors
-- for the rest of the 'HList', cast the current element to a
-- 'D.ATenTensor', and prepend it.
instance (Castable x D.ATenTensor) => Apply' TensorListFold (x, IO [D.ATenTensor]) (IO [D.ATenTensor]) where
  apply' _ (element, accumulated) = do
    rest <- accumulated
    converted <- cast element pure
    pure (converted : rest)

-- | Tag selecting the 'Apply' instances that 'hunfoldrM' uses to rebuild an
-- 'HList' from a list of raw 'D.ATenTensor' pointers.
data TensorListUnfold = TensorListUnfold

-- | Terminal unfold step: the target 'HList' is complete, so the pointer list
-- must be exhausted.
instance Apply TensorListUnfold [D.ATenTensor] (IO HNothing) where
  apply _ [] = pure HNothing
  -- Leftover pointers mean the runtime list does not match the 'HList' type;
  -- report that instead of failing with an anonymous pattern-match error.
  apply _ (_ : _) =
    ioError . userError $
      "TensorListUnfold: tensor list has more elements than the target HList"

-- | Productive unfold step: uncast the head pointer to the element type @x@
-- and hand the remaining pointers to the next step.
instance (Castable x D.ATenTensor) => Apply TensorListUnfold [D.ATenTensor] (IO (HJust (x, [D.ATenTensor]))) where
  apply _ (x : xs) = do
    -- passing 'return' as the continuation just extracts the uncast value
    x' <- uncast x return
    return $ HJust (x', xs)
  -- Running out of pointers before the 'HList' is filled is likewise a
  -- length mismatch.
  apply _ [] =
    ioError . userError $
      "TensorListUnfold: tensor list has fewer elements than the target HList"

-- | Convert an 'HList' of tensor-like values to and from a list of raw ATen
-- tensor pointers, element by element, using 'TensorListFold' and
-- 'TensorListUnfold'.
instance
  ( HFoldrM IO TensorListFold [D.ATenTensor] l [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] res,
    HUnfoldM IO TensorListUnfold res l,
    res ~ (HUnfoldMRes IO [D.ATenTensor] l)
  ) =>
  Castable (HList l) [D.ATenTensor]
  where
  cast xs f = f =<< go xs
    where
      -- the signature pins the fold's result type to the pointer list
      go :: HList l -> IO [D.ATenTensor]
      go = hfoldrM TensorListFold ([] :: [D.ATenTensor])
  uncast xs f = f =<< go xs
    where
      -- the signature pins the unfold's target to this instance's 'HList'
      go :: [D.ATenTensor] -> IO (HList l)
      go = hunfoldrM TensorListUnfold

-- | Convert an 'HList' to and from a native ATen tensor list by going through
-- the intermediate list of tensor pointers.
instance Castable (HList l) [D.ATenTensor] => Castable (HList l) (ForeignPtr ATen.TensorList) where
  cast xs f = do
    -- 'HList' -> [pointer] (continuation 'return' extracts the list) ...
    ptrs <- cast xs return :: IO [ForeignPtr ATen.Tensor]
    -- ... then [pointer] -> native tensor list, handed to the caller's 'f'
    cast ptrs f
  uncast xs f = uncast xs $ \(ptrs :: [ForeignPtr ATen.Tensor]) -> do
    -- native tensor list -> [pointer] -> 'HList', then pass it on to 'f'
    ts <- uncast ptrs return :: IO (HList l)
    f ts

--------------------------------------------------------------------------------
-- Move tensors
--------------------------------------------------------------------------------

-- | Convert a tensor to sparse layout via 'D.toSparse'.
--
-- TODO: track sparsity in tensor type (the result currently has the same
-- type as the input).
toSparse :: Tensor device dtype shape -> Tensor device dtype shape
toSparse t = UnsafeMkTensor $ D.toSparse (toDynamic t)

-- | Convert a tensor to dense layout via 'D.toDense'.
--
-- TODO: track sparsity in tensor type (the result currently has the same
-- type as the input).
toDense :: Tensor device dtype shape -> Tensor device dtype shape
toDense t = UnsafeMkTensor $ D.toDense (toDynamic t)

-- -- TODO: is this a device?
-- toMKLDNN
--   :: forall device' device shape dtype
--    . Tensor device  dtype shape
--   -> Tensor device' dtype shape
-- toMKLDNN t = UnsafeMkTensor $ D.toMKLDNN (toDynamic t)

-- | Move a tensor to the CPU via 'D.toCPU'; the result is re-tagged as a
-- 'CPUTensor' with unchanged dtype and shape.
--
-- TODO: can this fail?
toCPU ::
  forall device shape dtype.
  Tensor device dtype shape ->
  CPUTensor dtype shape
toCPU input = UnsafeMkTensor $ D.toCPU (toDynamic input)

-- | Move a tensor to the first CUDA device (index 0) via 'D.toCUDA'.
--
-- The leading type variable @device'@ is not used by the signature; it is
-- kept so the order of type applications stays unchanged.
--
-- TODO: what if this fails?
toCUDA ::
  forall device' device shape dtype.
  Tensor device dtype shape ->
  CUDATensor 0 dtype shape
toCUDA t = UnsafeMkTensor $ D.toCUDA (toDynamic t)

-- | Move a tensor to the first MPS device (index 0) via 'D.toMPS'.
--
-- The leading type variable @device'@ is not used by the signature; it is
-- kept so the order of type applications stays unchanged.
--
-- TODO: what if this fails?
toMPS ::
  forall device' device shape dtype.
  Tensor device dtype shape ->
  MPSTensor 0 dtype shape
toMPS t = UnsafeMkTensor $ D.toMPS (toDynamic t)

-- | Move a tensor (or any 'Unnamed' wrapper around one) to device @device'@.
--
-- The target device is the first type variable, so it is chosen with a type
-- application; its runtime value comes from 'deviceVal'. The result type is
-- the input type with its device replaced (@ReplaceDevice''@).
--
-- TODO: what if this fails?
toDevice ::
  forall device' device dtype shape t t'.
  ( KnownDevice device',
    IsUnnamed t device dtype shape,
    Unnamed t',
    t' ~ ReplaceDevice'' t device'
  ) =>
  t ->
  t'
toDevice = fromUnnamed . UnsafeMkTensor . D.toDevice (deviceVal @device') . toDynamic

-- | Change the data type of a tensor (or any 'Unnamed' wrapper around one)
-- to @dtype'@.
--
-- The target dtype is the first type variable, so it is chosen with a type
-- application; its runtime value comes from 'dtypeVal'. The result type is
-- the input type with its dtype replaced (@ReplaceDType''@).
toDType ::
  forall dtype' dtype device shape t t'.
  ( KnownDType dtype',
    IsUnnamed t device dtype shape,
    Unnamed t',
    t' ~ ReplaceDType'' t dtype'
  ) =>
  t ->
  t'
toDType = fromUnnamed . UnsafeMkTensor . D.toType (dtypeVal @dtype') . toDynamic

--------------------------------------------------------------------------------
-- Auxiliary functions for accessing tensor options as values
--------------------------------------------------------------------------------

-- | Returns the number of dimensions of a tensor, i.e. the length of its
--   shape.
--   Uses compile-time information only; the argument is never inspected.
dim ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  Int
dim _ = length $ optionsRuntimeShape @shape @dtype @device

-- | Returns the tensor shape as a list of dimension sizes.
--   Uses compile-time information only; the argument is never inspected.
shape ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  [Int]
shape _ = optionsRuntimeShape @shape @dtype @device

-- | Returns the tensor data type as a runtime 'D.DType'.
--   Uses compile-time information only; the argument is never inspected.
dtype ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  D.DType
dtype _ = optionsRuntimeDType @shape @dtype @device

-- | Returns the tensor device as a runtime 'D.Device'.
--   Uses compile-time information only; the argument is never inspected.
device ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  D.Device
device _ = optionsRuntimeDevice @shape @dtype @device

--------------------------------------------------------------------------------
-- Auxiliary functions for accessing tensors as values
--------------------------------------------------------------------------------

-- | Read a tensor's value as an 'Int' via 'D.toInt'.
--
-- TODO: figure out what device, dtype, and shape we need for this
toInt ::
  Tensor device dtype shape ->
  Int
toInt t = D.toInt $ toDynamic t

-- | Read the value of a scalar 'D.Float' tensor, moving it to the CPU first.
toFloat :: forall device. Tensor device 'D.Float '[] -> Float
toFloat t = D.asValue . toDynamic . toCPU $ t

-- | Read the value of a scalar 'D.Double' tensor, moving it to the CPU first.
toDouble :: forall device. Tensor device 'D.Double '[] -> Double
toDouble t = D.asValue . toDynamic . toCPU $ t

-- | Read the value of a scalar 'D.Bool' tensor, moving it to the CPU first.
toBool :: forall device. Tensor device 'D.Bool '[] -> Bool
toBool t = D.asValue . toDynamic . toCPU $ t

--------------------------------------------------------------------------------
-- NamedTensor
--------------------------------------------------------------------------------

-- | Maps a Haskell element type to the corresponding 'D.DType'.
-- An applied type @f a@ is unwrapped recursively, so a nested container
-- maps to the dtype of its innermost element type.
type family ToDType a :: D.DType where
  ToDType Bool = 'D.Bool
  ToDType Int = 'D.Int64
  ToDType Float = 'D.Float
  ToDType Double = 'D.Double
  ToDType (f a) = ToDType a

-- | Maps a (possibly nested) Haskell type to a 'Shape': every type
-- constructor @f@ wrapped around the element type contributes one entry,
-- outermost first, and the scalar element types contribute none.
type family ToShape a :: Shape where
  ToShape Bool = '[]
  ToShape Int = '[]
  ToShape Float = '[]
  ToShape Double = '[]
  ToShape (f a) = f ': ToShape a

-- | Zero-based position of the first occurrence of dimension @a@ in @shape@.
-- Fails at compile time with a custom type error if @a@ does not occur.
type family FindDim (a :: Size) (shape :: Shape) :: Nat where
  FindDim a (a ': _) = 0
  FindDim a (b ': ax) = 1 + FindDim a ax
  -- Only reached once the list is exhausted without a match.
  FindDim a _ =
    TypeError
      ( Text "Could not find type "
          :<>: ShowType a
          :<>: Text " in the shape."
      )

-- | A tensor whose shape is indexed by a type-level 'Shape' (a list of
-- 'Size' entries, as matched by 'FindDim') instead of a list of 'Nat's.
-- 'FromTensor' wraps a plain 'Tensor' whose 'Nat' shape is @ToNats shape'@.
data NamedTensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: Shape) where
  FromTensor :: forall device dtype shape' shape. shape ~ ToNats shape' => Tensor device dtype shape -> NamedTensor device dtype shape'

-- | A 'NamedTensor' is a thin wrapper: unwrapping yields the underlying
-- 'Tensor' with its shape converted by 'ToNats'.
instance Unnamed (NamedTensor device dtype shape) where
  type UTShape (NamedTensor device dtype shape) = ToNats shape
  type UTDevice (NamedTensor device dtype shape) = device
  type UTDType (NamedTensor device dtype shape) = dtype
  toUnnamed (FromTensor t) = t
  fromUnnamed = FromTensor
  -- strip both wrappers to reach the untyped tensor
  toDynamic (FromTensor (UnsafeMkTensor t)) = t

-- | Arithmetic on 'NamedTensor's: each operation unwraps its operands with
-- 'toUnnamed', applies the corresponding operation on the underlying
-- 'Tensor', and re-wraps the result with 'fromUnnamed'.
instance (KnownDevice device) => Num (NamedTensor device dtype shape) where
  a + b = fromUnnamed $ toUnnamed a + toUnnamed b
  a - b = fromUnnamed $ toUnnamed a - toUnnamed b
  a * b = fromUnnamed $ toUnnamed a * toUnnamed b
  negate = fromUnnamed . negate . toUnnamed
  abs = fromUnnamed . abs . toUnnamed
  signum = fromUnnamed . signum . toUnnamed
  fromInteger = fromUnnamed . fromInteger

-- | 'Fractional' for named tensors, delegated to the unnamed representation.
--
-- Every method converts its 'NamedTensor' operands to the underlying
-- @Tensor device dtype (ToNats shape)@ with 'toUnnamed'. It then applies the
-- corresponding 'Fractional' operation of that 'Tensor' and wraps the result
-- back up with 'fromUnnamed'. The named shape is therefore carried through
-- unchanged.
instance KnownDevice device => Fractional (NamedTensor device dtype shape) where
  -- Element-wise division, as performed by the underlying 'Tensor' instance.
  x / y = fromUnnamed (toUnnamed x / toUnnamed y)

  -- Reciprocal of the unnamed tensor, re-tagged with the original names.
  recip x = fromUnnamed (recip (toUnnamed x))

  -- Build the unnamed tensor from the literal first, then attach the names.
  fromRational r = fromUnnamed (fromRational r)

-- | Render a 'NamedTensor' exactly as its unnamed 'Tensor' counterpart.
--
-- The value is converted with 'toUnnamed' and that tensor's 'show' output is
-- returned verbatim, with no name information added.
instance Show (NamedTensor device dtype shape) where
  show namedTensor = show (toUnnamed namedTensor)

-- | Type-level update of the device index of a 'Tensor' or 'NamedTensor'
-- type. The dtype and shape parameters are passed through untouched; the old
-- device is discarded.
type family ReplaceDevice'' (tensor :: t) (device :: (D.DeviceType, Nat)) :: t where
  ReplaceDevice'' (Tensor _ dtype shape) newDevice = Tensor newDevice dtype shape
  ReplaceDevice'' (NamedTensor _ dtype shape) newDevice = NamedTensor newDevice dtype shape

-- | Type-level update of the dtype index of a 'Tensor' or 'NamedTensor'
-- type. The device and shape parameters are passed through untouched; the old
-- dtype is discarded.
type family ReplaceDType'' (tensor :: t) (dtype :: D.DType) :: t where
  ReplaceDType'' (Tensor device _ shape) newDType = Tensor device newDType shape
  ReplaceDType'' (NamedTensor device _ shape) newDType = NamedTensor device newDType shape