{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Torch.Backend where

import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Type as ATen

-- | Haskell-side tag for the tensor backends this module can exchange with
-- ATen. Each constructor corresponds to one @ATen.b*@ constant via the
-- 'Castable' instance for @Backend@ / @ATen.Backend@.
data Backend
  = CPU        -- ^ corresponds to @ATen.bCPU@
  | CUDA       -- ^ corresponds to @ATen.bCUDA@
  | HIP        -- ^ corresponds to @ATen.bHIP@
  | SparseCPU  -- ^ corresponds to @ATen.bSparseCPU@
  | SparseCUDA -- ^ corresponds to @ATen.bSparseCUDA@
  | XLA        -- ^ corresponds to @ATen.bXLA@
  | MPS        -- ^ corresponds to @ATen.bMPS@
  deriving (Eq, Show)

-- | Continuation-passing conversion between the Haskell 'Backend' tag and
-- the raw @ATen.Backend@ value.
--
-- 'cast' is total: every 'Backend' constructor has an ATen constant.
-- 'uncast' compares the raw value against each known ATen constant in turn.
-- A value matching none of them raises an 'IOError' in 'IO' instead of
-- crashing with GHC's generic "Non-exhaustive guards" failure.
instance Castable Backend ATen.Backend where
  cast CPU f = f ATen.bCPU
  cast CUDA f = f ATen.bCUDA
  cast MPS f = f ATen.bMPS
  cast HIP f = f ATen.bHIP
  cast SparseCPU f = f ATen.bSparseCPU
  cast SparseCUDA f = f ATen.bSparseCUDA
  cast XLA f = f ATen.bXLA

  uncast x f
    | x == ATen.bCPU = f CPU
    | x == ATen.bCUDA = f CUDA
    | x == ATen.bMPS = f MPS
    | x == ATen.bHIP = f HIP
    | x == ATen.bSparseCPU = f SparseCPU
    | x == ATen.bSparseCUDA = f SparseCUDA
    | x == ATen.bXLA = f XLA
    -- Explicit fallback: ATen may report a backend this enumeration does not
    -- model. Fail in IO with a descriptive message rather than a pattern-match
    -- crash. The raw value is not shown to avoid requiring a Show instance on
    -- ATen.Backend.
    | otherwise =
        ioError
          ( userError
              "Torch.Backend.uncast: ATen backend value does not correspond to any known Backend constructor"
          )