{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Torch.Internal.Managed.Cast where
import Control.Exception.Safe (throwIO)
import Foreign.ForeignPtr
import Foreign.C.Types
import Data.Int
import Control.Monad
import Torch.Internal.Class
import Torch.Internal.Cast
import Torch.Internal.Type
import Torch.Internal.Managed.Type.IntArray
import Torch.Internal.Managed.Type.TensorList
import Torch.Internal.Managed.Type.C10List
import Torch.Internal.Managed.Type.IValueList
import Torch.Internal.Managed.Type.C10Tuple
import Torch.Internal.Managed.Type.C10Dict
import Torch.Internal.Managed.Type.StdVector
-- | Marshal a single 'Int' as a one-element 'IntArray'.
--
-- 'cast' creates an array with 'newIntArray', pushes the value (converted
-- with 'fromIntegral') into it, then runs the continuation on the array.
-- 'uncast' reads back only the element at index 0.
instance Castable Int (ForeignPtr IntArray) where
  cast x f =
    newIntArray >>= \arr ->
      intArray_push_back_l arr (fromIntegral x) >> f arr

  -- NOTE(review): no size check is done before reading index 0; this
  -- presumably assumes the array holds at least one element — confirm.
  uncast xs f = intArray_at_s xs 0 >>= f . fromIntegral
-- | Marshal a list of 'Int's to and from an 'IntArray'.
--
-- Elements are converted with 'fromIntegral' in both directions
-- ('intArray_toList' yields 'Int64' values).
instance Castable [Int] (ForeignPtr IntArray) where
  cast xs f = do
    arr <- newIntArray
    intArray_fromList arr (fromIntegral <$> xs)
    f arr

  uncast xs f = intArray_toList xs >>= f . map fromIntegral
-- | Marshal a list of 'Double's to and from a 'StdVector' of 'CDouble'.
--
-- Values are converted between 'Double' and 'CDouble' with 'realToFrac'.
instance Castable [Double] (ForeignPtr (StdVector CDouble)) where
  cast xs f = do
    vec <- newStdVectorDouble
    mapM_ (stdVectorDouble_push_back vec . realToFrac) xs
    f vec

  uncast xs f = do
    n <- stdVectorDouble_size xs
    -- The size is an unsigned 'CSize', so @n - 1@ would wrap around for an
    -- empty vector; the explicit zero check avoids building that range.
    vals <-
      if n == 0
        then pure []
        else mapM (fmap realToFrac . stdVectorDouble_at xs) [0 .. n - 1]
    f vals
-- | Marshal a list of tensors to and from a 'TensorList'.
--
-- 'cast' pushes every tensor into a fresh list from 'newTensorList';
-- 'uncast' reads the elements back by index, from 0 to size - 1.
instance Castable [ForeignPtr Tensor] (ForeignPtr TensorList) where
  cast xs f = do
    l <- newTensorList
    forM_ xs (tensorList_push_back_t l)
    f l

  uncast xs f = do
    len <- tensorList_size xs
    -- 'len' is an unsigned 'CSize': for an empty list @len - 1@ wraps around
    -- to 'maxBound', so @[0 .. len - 1]@ would index far past the end.
    if len == 0
      then f []
      else f =<< mapM (tensorList_at_s xs) [0 .. len - 1]
-- | Marshal a list of tensors to and from a @C10List Tensor@.
--
-- 'cast' pushes every tensor into a fresh list from 'newC10ListTensor';
-- 'uncast' reads the elements back by index, from 0 to size - 1.
instance Castable [ForeignPtr Tensor] (ForeignPtr (C10List Tensor)) where
  cast xs f = do
    l <- newC10ListTensor
    forM_ xs (c10ListTensor_push_back l)
    f l

  uncast xs f = do
    len <- c10ListTensor_size xs
    -- 'len' is an unsigned 'CSize': for an empty list @len - 1@ wraps around
    -- to 'maxBound', so @[0 .. len - 1]@ would index far past the end.
    if len == 0
      then f []
      else f =<< mapM (c10ListTensor_at xs) [0 .. len - 1]
-- | Marshal a list of tensors to and from a @C10List (C10Optional Tensor)@.
--
-- 'cast' pushes every tensor into a fresh list from
-- 'newC10ListOptionalTensor'; 'uncast' reads the elements back by index,
-- from 0 to size - 1.
instance Castable [ForeignPtr Tensor] (ForeignPtr (C10List (C10Optional Tensor))) where
  cast xs f = do
    l <- newC10ListOptionalTensor
    forM_ xs (c10ListOptionalTensor_push_back l)
    f l

  uncast xs f = do
    len <- c10ListOptionalTensor_size xs
    -- 'len' is an unsigned 'CSize': for an empty list @len - 1@ wraps around
    -- to 'maxBound', so @[0 .. len - 1]@ would index far past the end.
    if len == 0
      then f []
      else f =<< mapM (c10ListOptionalTensor_at xs) [0 .. len - 1]
-- | Marshal a list of 'CDouble's to and from a @C10List CDouble@.
--
-- 'cast' pushes every value into a fresh list from 'newC10ListDouble';
-- 'uncast' reads the values back by index, from 0 to size - 1.
instance Castable [CDouble] (ForeignPtr (C10List CDouble)) where
  cast xs f = do
    l <- newC10ListDouble
    forM_ xs (c10ListDouble_push_back l)
    f l

  uncast xs f = do
    len <- c10ListDouble_size xs
    -- 'len' is an unsigned 'CSize': for an empty list @len - 1@ wraps around
    -- to 'maxBound', so @[0 .. len - 1]@ would index far past the end.
    if len == 0
      then f []
      else f =<< mapM (c10ListDouble_at xs) [0 .. len - 1]
-- | Marshal a list of 'Int64's to and from a @C10List Int64@.
--
-- 'cast' pushes every value into a fresh list from 'newC10ListInt';
-- 'uncast' reads the values back by index, from 0 to size - 1.
instance Castable [Int64] (ForeignPtr (C10List Int64)) where
  cast xs f = do
    l <- newC10ListInt
    forM_ xs (c10ListInt_push_back l)
    f l

  uncast xs f = do
    len <- c10ListInt_size xs
    -- 'len' is an unsigned 'CSize': for an empty list @len - 1@ wraps around
    -- to 'maxBound', so @[0 .. len - 1]@ would index far past the end.
    if len == 0
      then f []
      else f =<< mapM (c10ListInt_at xs) [0 .. len - 1]
-- | Marshal a list of 'CBool's to and from a @C10List CBool@.
--
-- 'cast' pushes every value into a fresh list from 'newC10ListBool';
-- 'uncast' reads the values back by index, from 0 to size - 1.
instance Castable [CBool] (ForeignPtr (C10List CBool)) where
  cast xs f = do
    l <- newC10ListBool
    forM_ xs (c10ListBool_push_back l)
    f l

  uncast xs f = do
    len <- c10ListBool_size xs
    -- 'len' is an unsigned 'CSize': for an empty list @len - 1@ wraps around
    -- to 'maxBound', so @[0 .. len - 1]@ would index far past the end.
    if len == 0
      then f []
      else f =<< mapM (c10ListBool_at xs) [0 .. len - 1]
-- | Marshal a list of 'IValue's to and from an 'IValueList'.
--
-- 'cast' pushes every value into a fresh list from 'newIValueList';
-- 'uncast' reads the values back by index, from 0 to size - 1.
instance Castable [ForeignPtr IValue] (ForeignPtr IValueList) where
  cast xs f = do
    l <- newIValueList
    forM_ xs (ivalueList_push_back l)
    f l

  uncast xs f = do
    len <- ivalueList_size xs
    -- 'len' is an unsigned 'CSize': for an empty list @len - 1@ wraps around
    -- to 'maxBound', so @[0 .. len - 1]@ would index far past the end.
    if len == 0
      then f []
      else f =<< mapM (ivalueList_at xs) [0 .. len - 1]
-- | Marshal a list of 'IValue's to and from a tuple (@C10Ptr IVTuple@).
--
-- 'cast' first builds an 'IValueList' through the
-- @[ForeignPtr IValue] -> ForeignPtr IValueList@ instance, then wraps it with
-- 'newC10Tuple_tuple'. 'uncast' reads the tuple's elements back by index,
-- from 0 to size - 1.
instance Castable [ForeignPtr IValue] (ForeignPtr (C10Ptr IVTuple)) where
  cast xs f =
    cast xs $ \ivalueList -> f =<< newC10Tuple_tuple ivalueList

  uncast xs f = do
    len <- c10Tuple_size xs
    -- 'len' is an unsigned 'CSize': for an empty tuple @len - 1@ wraps around
    -- to 'maxBound', so @[0 .. len - 1]@ would index far past the end.
    if len == 0
      then f []
      else f =<< mapM (c10Tuple_at xs) [0 .. len - 1]
-- | Marshal a non-empty list of 'IValue's to and from a @C10List IValue@.
--
-- 'cast' throws a 'userError' on an empty input, because
-- 'newC10ListIValue' is called with the first element. Every element,
-- including that first one, is then pushed into the list.
-- NOTE(review): confirm 'newC10ListIValue' does not itself insert its
-- argument; if it did, the first element would appear twice.
-- 'uncast' reads the elements back by index, from 0 to size - 1.
instance Castable [ForeignPtr IValue] (ForeignPtr (C10List IValue)) where
  cast [] _ =
    throwIO $ userError "[ForeignPtr IValue]'s length must be one or more."
  -- Match the head directly instead of calling the partial 'head'.
  cast xs@(x : _) f = do
    l <- newC10ListIValue x
    forM_ xs (c10ListIValue_push_back l)
    f l

  uncast xs f = do
    len <- c10ListIValue_size xs
    -- 'len' is an unsigned 'CSize': for an empty list @len - 1@ wraps around
    -- to 'maxBound', so @[0 .. len - 1]@ would index far past the end.
    if len == 0
      then f []
      else f =<< mapM (c10ListIValue_at xs) [0 .. len - 1]
-- | Marshal a non-empty association list of 'IValue' pairs to and from a
-- @C10Dict '(IValue, IValue)@.
--
-- 'cast' throws a 'userError' on an empty input, because 'newC10Dict' is
-- called with the first key/value pair. Every pair, including the first,
-- is then inserted with 'c10Dict_insert'.
-- 'uncast' converts the dictionary with 'c10Dict_toList'.
instance Castable [(ForeignPtr IValue,ForeignPtr IValue)] (ForeignPtr (C10Dict '(IValue,IValue))) where
  cast [] _ =
    throwIO $ userError "[(ForeignPtr IValue,ForeignPtr IValue)]'s length must be one or more."
  -- Match the first pair directly instead of calling the partial 'head'.
  cast xs@((k0, v0) : _) f = do
    l <- newC10Dict k0 v0
    forM_ xs $ \(k, v) -> c10Dict_insert l k v
    f l

  uncast xs f = f =<< c10Dict_toList xs