{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Torch.Device where

import qualified Data.Int as I
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Type as ATen

-- | The kind of backend a 'Device' belongs to.
--
-- The derived 'Ord' instance orders constructors by declaration order
-- (@CPU < CUDA < MPS@). Reordering them changes that ordering, and also the
-- ordering of 'Device' values, which compare 'DeviceType' first.
data DeviceType
  = -- | Converted to and from ATen's @kCPU@ code.
    CPU
  | -- | Converted to and from ATen's @kCUDA@ code.
    CUDA
  | -- | Converted to and from ATen's @kMPS@ code.
    MPS
  deriving (Eq, Ord, Show)

-- | A concrete device: a backend kind together with an index.
--
-- The derived 'Ord' instance compares 'deviceType' first and then
-- 'deviceIndex', following field order.
data Device = Device
  { -- | Which backend this device belongs to.
    deviceType :: DeviceType,
    -- | Index of the device. NOTE(review): presumably the ordinal among
    -- devices of the same 'deviceType' (e.g. which GPU); confirm the
    -- convention callers use for 'CPU'.
    deviceIndex :: I.Int16
  }
  deriving (Eq, Ord, Show)

-- | Conversion between the Haskell 'DeviceType' and ATen's device-type code.
--
-- 'cast' maps each constructor to the matching @ATen.k*@ constant and passes
-- it to the continuation. 'uncast' is the inverse. It raises an 'IOError'
-- for any code it does not recognise, instead of failing with an opaque
-- "Non-exhaustive guards" error.
instance Castable DeviceType ATen.DeviceType where
  cast CPU f = f ATen.kCPU
  cast CUDA f = f ATen.kCUDA
  cast MPS f = f ATen.kMPS

  -- The ATen constants are values, not constructors, so they are compared
  -- with '==' in guards rather than pattern-matched.
  uncast x f
    | x == ATen.kCPU = f CPU
    | x == ATen.kCUDA = f CUDA
    | x == ATen.kMPS = f MPS
    -- Any other code (if ATen ever hands one back) has no Haskell
    -- counterpart here; fail in IO with a message naming the offending code.
    | otherwise =
        ioError . userError $
          "Torch.Device.uncast: unsupported ATen device type code " ++ show x