{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}

module Torch.Random
  ( mkGenerator,
    Generator,
    randn,
    randn',
    rand,
    rand',
    randint,
    randint',
    normal,
    normal',
  )
where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad.IO.Class
import Control.Monad.STM
import Data.Int
import Data.Word
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Internal.Managed.Type.Generator as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Tensor
import Torch.TensorOptions

-- | Orphan 'Show' instance for the state cell wrapped by 'Generator', so that
-- 'Generator' can derive 'Show'. The cell has no useful textual form, so
-- every value renders as @_@.
instance Show (TVar (Either (Word64, Device) (ForeignPtr ATen.Generator))) where
  showsPrec _ _ = showString "_"

-- | Handle to a native (libtorch) random-number generator.
--
-- Only the type is exported; the @Unsafe…@ constructor stays private to this
-- module. The derived 'Eq' compares the wrapped 'TVar's, and the derived
-- 'Show' relies on the orphan instance for the 'TVar' defined in this module.
newtype Generator = UnsafeGenerator
  { -- | Either a live native generator ('Right') or a @(seed, device)@ pair
    -- ('Left') from which a native generator can be rebuilt.
    unGenerator :: TVar (Either (Word64, Device) (ForeignPtr ATen.Generator))
  }
  deriving (Show, Eq)

-- | Create a 'Generator' on the given device, seeded with the given value.
--
-- A CPU generator is constructed directly from the seed. CUDA generators
-- (built for the device's index) and MPS generators are constructed first and
-- then seeded with 'ATen.generator_set_current_seed'.
mkGenerator :: Device -> Word64 -> IO Generator
mkGenerator device seed = do
  native <- allocate device
  UnsafeGenerator <$> newTVarIO (Right native)
  where
    -- Build the native generator for the requested device.
    allocate (Device CPU _) = ATen.newCPUGenerator seed
    allocate (Device CUDA idx) = ATen.newCUDAGenerator (fromIntegral idx) >>= withSeed
    allocate (Device MPS _) = ATen.newMPSGenerator >>= withSeed

    -- Seed a freshly created native generator and hand it back.
    withSeed gen = ATen.generator_set_current_seed gen seed >> return gen

-- | Shape of the native random-tensor factories driven by 'generatorFactory'.
-- Each takes the requested size, a native generator and tensor options, and
-- returns a native tensor.
type RandomGenFunc =
  ForeignPtr ATen.IntArray ->
  ForeignPtr ATen.Generator ->
  ForeignPtr ATen.TensorOptions ->
  IO (ForeignPtr ATen.Tensor)

-- | Sample a tensor with a native random factory, threading a 'Generator'.
--
-- @generatorFactory func size options gen@ calls @func@ (through 'cast3')
-- with @size@, a native generator and @options@. It returns the resulting
-- tensor together with a new 'Generator' that wraps the native generator
-- which was actually used.
--
-- The input generator's 'TVar' is read atomically:
--
-- * 'Right' (live native generator): that pointer is used for sampling. The
--   'TVar' is overwritten with a 'Left' @(seed, device)@ snapshot read from
--   the pointer, so the old handle stops referring to the native generator
--   that the returned handle now owns.
--
-- * 'Left' (snapshot): a new native generator is built on the recorded
--   device and seeded with the recorded seed.
--
-- NOTE(review): the snapshot stores 'ATen.generator_current_seed', not the
-- advanced generator state, so drawing again from an already-used
-- 'Generator' appears to restart from that seed. Confirm this is intended.
--
-- Runs under 'unsafePerformIO' to present a pure interface.
generatorFactory ::
  RandomGenFunc ->
  [Int] ->
  TensorOptions ->
  Generator ->
  (Tensor, Generator)
generatorFactory func size options (UnsafeGenerator generator) =
  unsafePerformIO $ do
    taken <- atomically $ do
      current <- readTVar generator
      case current of
        -- 'writeTVar' stores its argument unevaluated. Without ($!) the
        -- snapshot's seq chain would only run when the TVar is next read,
        -- and until then the stored thunk would keep 'ptr' alive and defer
        -- the native queries. Forcing here takes the snapshot now.
        Right ptr -> writeTVar generator $! snapshotOf ptr
        Left _ -> return ()
      return current
    genPtr <- either restore return taken
    tensor <- cast3 func size genPtr options
    nextGenerator <- newTVarIO (Right genPtr)
    return (tensor, UnsafeGenerator nextGenerator)
  where
    -- (seed, device) needed to rebuild @gen@ later. When the result is forced
    -- to WHNF, the seed and both device fields are also evaluated.
    snapshotOf ::
      ForeignPtr ATen.Generator ->
      Either (Word64, Device) (ForeignPtr ATen.Generator)
    snapshotOf gen =
      let device = deviceOf gen
          seed = generatorSeed gen
       in seed
            `seq` deviceType device
            `seq` deviceIndex device
            `seq` Left (seed, device)

    -- Device a native generator lives on. CPU is the fallback when it is
    -- neither CUDA nor MPS.
    deviceOf :: ForeignPtr ATen.Generator -> Device
    deviceOf gen
      | generatorIsCuda gen =
          Device {deviceType = CUDA, deviceIndex = fromIntegral (generatorDevice gen)}
      | generatorIsMps gen = Device {deviceType = MPS, deviceIndex = 0}
      | otherwise = Device {deviceType = CPU, deviceIndex = 0}

    -- Rebuild a native generator from a (seed, device) snapshot.
    restore :: (Word64, Device) -> IO (ForeignPtr ATen.Generator)
    restore (seed, device) = case device of
      Device CPU _ -> ATen.newCPUGenerator seed
      Device CUDA idx -> do
        gen <- ATen.newCUDAGenerator (fromIntegral idx)
        ATen.generator_set_current_seed gen seed
        return gen
      Device MPS _ -> do
        gen <- ATen.newMPSGenerator
        ATen.generator_set_current_seed gen seed
        return gen

    -- Pure wrappers over native generator queries. The queried values do
    -- not change when the generator is advanced by sampling.
    generatorIsCuda :: ForeignPtr ATen.Generator -> Bool
    generatorIsCuda gen = unsafePerformIO $ cast1 ATen.generator_is_cuda gen

    generatorIsMps :: ForeignPtr ATen.Generator -> Bool
    generatorIsMps gen = unsafePerformIO $ cast1 ATen.generator_is_mps gen

    generatorDevice :: ForeignPtr ATen.Generator -> Int
    generatorDevice gen = unsafePerformIO $ cast1 ATen.generator_get_device gen

    generatorSeed :: ForeignPtr ATen.Generator -> Word64
    generatorSeed gen = unsafePerformIO $ cast1 ATen.generator_current_seed gen

-- | Sample a tensor of the given size with libtorch's @randn@ factory, using
-- the given options and drawing randomness from the given 'Generator'.
--
-- Returns the tensor and the 'Generator' to use for the next draw.
-- NOTE(review): libtorch's @randn@ samples a standard normal distribution;
-- confirm against the binding.
randn ::
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
randn size opts gen = generatorFactory LibTorch.randn_lGo size opts gen

-- | 'randn' with 'defaultOpts' as the tensor options.
randn' ::
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
randn' size gen = randn size defaultOpts gen

-- | Sample a tensor of the given size with libtorch's @rand@ factory, using
-- the given options and drawing randomness from the given 'Generator'.
--
-- Returns the tensor and the 'Generator' to use for the next draw.
-- NOTE(review): libtorch's @rand@ samples uniformly from [0, 1); confirm
-- against the binding.
rand ::
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
rand size opts gen = generatorFactory LibTorch.rand_lGo size opts gen

-- | 'rand' with 'defaultOpts' as the tensor options.
rand' ::
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
rand' size gen = rand size defaultOpts gen

-- | Sample an integer-valued tensor of the given size with libtorch's
-- @randint@ factory, bounded by @low@ and @high@ (both converted to 64-bit
-- integers for the native call).
--
-- Returns the tensor and the 'Generator' to use for the next draw.
-- NOTE(review): libtorch's @randint@ draws from the half-open range
-- [low, high); confirm against the binding.
randint ::
  -- | low
  Int ->
  -- | high
  Int ->
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
randint low high size opts gen = generatorFactory sampler size opts gen
  where
    sampler = LibTorch.randint_lllGo (fromIntegral low) (fromIntegral high)

-- | 'randint' with 'defaultOpts' as the tensor options.
randint' ::
  -- | low
  Int ->
  -- | high
  Int ->
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
randint' low high size gen = randint low high size defaultOpts gen

-- | Sample a tensor of the given size with libtorch's @normal@ factory, using
-- the given mean and standard deviation (converted to C doubles for the
-- native call).
--
-- Returns the tensor and the 'Generator' to use for the next draw.
normal ::
  -- | mean
  Double ->
  -- | std
  Double ->
  -- | size
  [Int] ->
  -- | options
  TensorOptions ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
normal mean std size opts gen = generatorFactory sampler size opts gen
  where
    sampler = LibTorch.normal_ddlGo (realToFrac mean) (realToFrac std)

-- | 'normal' with 'defaultOpts' as the tensor options.
normal' ::
  -- | mean
  Double ->
  -- | std
  Double ->
  -- | size
  [Int] ->
  -- | generator
  Generator ->
  -- | output
  (Tensor, Generator)
normal' mean std size gen = normal mean std size defaultOpts gen