{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
module Torch.Random
  ( mkGenerator,
    Generator,
    randn,
    randn',
    rand,
    rand',
    randint,
    randint',
    normal,
    normal',
  )
where
import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad.IO.Class
import Control.Monad.STM
import Data.Int
import Data.Word
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Internal.Managed.Type.Generator as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Tensor
import Torch.TensorOptions
-- | Opaque 'Show' for a generator's state cell, so that 'Generator' can
-- derive 'Show'. The contents (a native pointer or a saved seed/device
-- pair) are never rendered; every value shows as @_@.
--
-- NOTE(review): this is an orphan instance (neither 'Show' nor 'TVar' is
-- defined in this module), so it leaks to every importer of this module.
-- A hand-written @instance Show Generator@ would avoid the orphan.
instance Show (TVar (Either (Word64, Device) (ForeignPtr ATen.Generator))) where
  show _ = "_"
-- | Handle to a random number generator used by the sampling functions in
-- this module.
--
-- The 'TVar' holds either a live native generator ('Right'), or — once the
-- handle has been consumed by a sampling call — the @(seed, device)@ pair
-- from which a native generator is rebuilt if the handle is used again
-- ('Left').
--
-- 'Eq' compares the underlying 'TVar's, i.e. handle identity rather than
-- generator state; 'Show' prints the state as @_@.
newtype Generator = UnsafeGenerator
  { unGenerator :: TVar (Either (Word64, Device) (ForeignPtr ATen.Generator))
  }
  deriving (Eq, Show)
-- | Create a generator on the given device, seeded with @seed@.
--
-- CPU generators receive the seed at construction. CUDA and MPS generators
-- are constructed first and then seeded with
-- 'ATen.generator_set_current_seed'. Only the CUDA branch uses the device
-- index; it is ignored for CPU and MPS.
mkGenerator :: Device -> Word64 -> IO Generator
mkGenerator device seed = do
  genPtr <- case device of
    Device CPU _ -> ATen.newCPUGenerator seed
    Device CUDA idx -> do
      gen <- ATen.newCUDAGenerator (fromIntegral idx)
      ATen.generator_set_current_seed gen seed
      pure gen
    Device MPS _ -> do
      gen <- ATen.newMPSGenerator
      ATen.generator_set_current_seed gen seed
      pure gen
  -- A fresh handle always starts with the live native generator.
  UnsafeGenerator <$> newTVarIO (Right genPtr)
-- | Shape of the libtorch sampling entry points that 'generatorFactory'
-- drives. The arguments are the output size, the native generator to draw
-- from, and the tensor options; the result is the sampled tensor.
type RandomGenFunc =
  ForeignPtr ATen.IntArray ->
  ForeignPtr ATen.Generator ->
  ForeignPtr ATen.TensorOptions ->
  IO (ForeignPtr ATen.Tensor)
-- | Run a libtorch sampling function against a 'Generator'. Returns the
-- sampled tensor together with a new 'Generator' handle for the next draw.
--
-- The native generator is stateful, but this interface is pure. To stop a
-- reused handle from silently observing advanced state, the input handle's
-- 'TVar' is switched atomically from the live pointer ('Right') to the
-- @(seed, device)@ read from it ('Left'). The live pointer moves into the
-- returned handle. Sampling again from the old handle rebuilds a native
-- generator from the saved seed.
--
-- NOTE(review): the seed comes from 'ATen.generator_current_seed'.
-- Presumably that reports the seed, not the stream position. If so, a
-- rebuilt generator restarts from the seed and only reproduces earlier
-- draws for handles fresh from 'mkGenerator'. Confirm this is intended.
generatorFactory :: RandomGenFunc -> [Int] -> TensorOptions -> Generator -> (Tensor, Generator)
generatorFactory func size options (UnsafeGenerator generator) =
  unsafePerformIO $ do
    -- Atomically claim the live native generator (if any). Only what is
    -- needed to rebuild it is left behind in the caller's handle.
    claimed <- atomically $ do
      current <- readTVar generator
      case current of
        Right ptr -> do
          let device = deviceOf ptr
              seed = generatorSeed ptr
          -- writeTVar does not evaluate its argument. Use $! so the
          -- FFI-backed reads happen inside this transaction; otherwise the
          -- TVar would hold a thunk that defers them and keeps @ptr@
          -- reachable from the stale handle.
          writeTVar generator $!
            seed `seq` deviceType device `seq` deviceIndex device `seq` Left (seed, device)
          pure (Right ptr)
        Left saved -> pure (Left saved)
    genPtr <- case claimed of
      Right ptr -> pure ptr
      Left (seed, device) -> newNativeGenerator device seed
    tensor <- cast3 func size genPtr options
    next <- newTVarIO (Right genPtr)
    pure (tensor, UnsafeGenerator next)
  where
    -- Device a native generator lives on; CPU and MPS always report index 0.
    deviceOf :: ForeignPtr ATen.Generator -> Device
    deviceOf ptr
      | generatorIsCuda ptr = Device {deviceType = CUDA, deviceIndex = fromIntegral (generatorDevice ptr)}
      | generatorIsMps ptr = Device {deviceType = MPS, deviceIndex = 0}
      | otherwise = Device {deviceType = CPU, deviceIndex = 0}

    -- Build a native generator for a saved (seed, device) pair. Uses the
    -- same per-device construction as mkGenerator.
    newNativeGenerator :: Device -> Word64 -> IO (ForeignPtr ATen.Generator)
    newNativeGenerator (Device CPU _) seed = ATen.newCPUGenerator seed
    newNativeGenerator (Device CUDA idx) seed = do
      gen <- ATen.newCUDAGenerator (fromIntegral idx)
      ATen.generator_set_current_seed gen seed
      pure gen
    newNativeGenerator (Device MPS _) seed = do
      gen <- ATen.newMPSGenerator
      ATen.generator_set_current_seed gen seed
      pure gen

    -- Read-only queries on the native generator, exposed as pure values.
    generatorIsCuda :: ForeignPtr ATen.Generator -> Bool
    generatorIsCuda gen = unsafePerformIO $ cast1 ATen.generator_is_cuda gen

    generatorIsMps :: ForeignPtr ATen.Generator -> Bool
    generatorIsMps gen = unsafePerformIO $ cast1 ATen.generator_is_mps gen

    generatorDevice :: ForeignPtr ATen.Generator -> Int
    generatorDevice gen = unsafePerformIO $ cast1 ATen.generator_get_device gen

    generatorSeed :: ForeignPtr ATen.Generator -> Word64
    generatorSeed gen = unsafePerformIO $ cast1 ATen.generator_current_seed gen
-- | Sample a tensor of the given size with libtorch's @randn@ (standard
-- normal, as @torch.randn@), drawing from the given generator.
randn ::
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | sampled tensor and the generator to use for the next draw
  (Tensor, Generator)
randn = generatorFactory LibTorch.randn_lGo
-- | 'randn' with 'defaultOpts'.
randn' ::
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | sampled tensor and the generator to use for the next draw
  (Tensor, Generator)
randn' size = randn size defaultOpts
-- | Sample a tensor of the given size with libtorch's @rand@ (uniform on
-- [0, 1), as @torch.rand@), drawing from the given generator.
rand ::
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | sampled tensor and the generator to use for the next draw
  (Tensor, Generator)
rand = generatorFactory LibTorch.rand_lGo
-- | 'rand' with 'defaultOpts'.
rand' ::
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | sampled tensor and the generator to use for the next draw
  (Tensor, Generator)
rand' size = rand size defaultOpts
-- | Sample integers with libtorch's @randint@, drawing from the given
-- generator. Following @torch.randint@, values lie in @[low, high)@.
randint ::
  -- | low (inclusive)
  Int ->
  -- | high (exclusive)
  Int ->
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | sampled tensor and the generator to use for the next draw
  (Tensor, Generator)
randint low high =
  generatorFactory (LibTorch.randint_lllGo (fromIntegral low) (fromIntegral high))
-- | 'randint' with 'defaultOpts'.
randint' ::
  -- | low (inclusive)
  Int ->
  -- | high (exclusive)
  Int ->
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | sampled tensor and the generator to use for the next draw
  (Tensor, Generator)
randint' low high size = randint low high size defaultOpts
-- | Sample from a normal distribution with the given mean and standard
-- deviation using libtorch's @normal@, drawing from the given generator.
normal ::
  -- | mean
  Double ->
  -- | standard deviation
  Double ->
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | sampled tensor and the generator to use for the next draw
  (Tensor, Generator)
normal mean std =
  generatorFactory (LibTorch.normal_ddlGo (realToFrac mean) (realToFrac std))
-- | 'normal' with 'defaultOpts'.
normal' ::
  -- | mean
  Double ->
  -- | standard deviation
  Double ->
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | sampled tensor and the generator to use for the next draw
  (Tensor, Generator)
normal' mean std size = normal mean std size defaultOpts