{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.Dimname where

import Data.String
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Type.Dimname as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Symbol as ATen
import qualified Torch.Internal.Type as ATen

-- | Haskell-side handle to an ATen @Dimname@ object, held through a
-- 'ForeignPtr'. Build one from a string via the 'IsString' instance, and
-- convert to or from the raw pointer (or a @DimnameList@) via 'Castable'.
newtype Dimname = Dimname (ForeignPtr ATen.Dimname)

-- | Build a 'Dimname' from a 'String' (e.g. a literal under
-- @OverloadedStrings@).
--
-- The steps run in order:
--
-- 1. The text is copied into a native string with 'ATen.newStdString_s'.
-- 2. 'ATen.dimname_s' turns it into a symbol.
-- 3. 'ATen.fromSymbol_s' converts that symbol into the wrapped @Dimname@.
--
-- 'unsafePerformIO' lets this be a pure 'fromString'.
-- NOTE(review): that is only sound if these FFI calls have no observable
-- effects beyond allocating fresh objects. Confirm.
instance IsString Dimname where
  fromString name =
    unsafePerformIO $
      Dimname
        <$> (ATen.fromSymbol_s =<< ATen.dimname_s =<< ATen.newStdString_s name)

-- | A 'Dimname' is marshalled by exposing its underlying pointer.
-- 'cast' hands the wrapped 'ForeignPtr' to the continuation, and 'uncast'
-- wraps a raw pointer back into a 'Dimname' before calling the continuation.
-- No copying or FFI calls happen in either direction.
instance Castable Dimname (ForeignPtr ATen.Dimname) where
  cast (Dimname ptr) = ($ ptr)
  uncast ptr = ($ Dimname ptr)

-- | Marshal a Haskell list of 'Dimname's to and from an ATen @DimnameList@.
--
-- Both directions go through the element-pointer representation
-- @[ForeignPtr ATen.Dimname]@. Each 'Dimname' is unwrapped to (or rewrapped
-- from) its pointer, and the pointer list is converted with the
-- @Castable [ForeignPtr ATen.Dimname] (ForeignPtr ATen.DimnameList)@
-- instance.
--
-- The type annotations on the inner calls are load-bearing. If the
-- intermediate list were typed @[Dimname]@ again, the inner 'cast' or
-- 'uncast' would resolve to this very instance and recurse forever.
--
-- NOTE(review): assumes the element-list instance above is provided by
-- libtorch-ffi and visible in this module. Confirm.
instance Castable [Dimname] (ForeignPtr ATen.DimnameList) where
  -- Unwrap every 'Dimname', then let the element-list instance build the
  -- native list and pass it to the continuation.
  cast xs f = cast (map toPtr xs :: [ForeignPtr ATen.Dimname]) f
    where
      toPtr (Dimname ptr) = ptr

  -- Decode the native list into raw element pointers first, then wrap each
  -- pointer as a 'Dimname' before handing the list to the continuation.
  uncast xs f = uncast xs $ \(ptrs :: [ForeignPtr ATen.Dimname]) ->
    f (map Dimname ptrs)