-- | The encryption scheme used for DTB files found on game discs.
module Data.DTA.Crypt
( oldCrypt, newCrypt
, decrypt, encrypt
, decryptFile, encryptFile
, decryptHandle, encryptHandle
, Key, Crypt
) where

import           Control.Monad         (forM_, liftM2, liftM3)
import           Control.Monad.ST.Lazy (ST, runST)
import           Data.Array.ST         (STArray, newArray, readArray,
                                        writeArray)
import           Data.Bits             (shiftR, xor, (.&.), (.|.))
import           Data.STRef.Lazy       (STRef, newSTRef, readSTRef, writeSTRef)
import           Data.Word             (Word32, Word8)
import           System.IO             (Handle)

import           Data.Binary.Get       (getRemainingLazyByteString, getWord32le,
                                        runGet)
import           Data.Binary.Put       (putLazyByteString, putWord32le, runPut)
import qualified Data.ByteString.Lazy  as BL

-- | An encryption/decryption key.
type Key = Word32

-- | A keystream generator: uses a key to produce an infinite stream of crypt
-- bytes.
type Crypt = Key -> [Word8]

{- |
Both the new and old DTB encryption algorithms work by using the key to
generate a stream of bytes, each of which is XOR'd with the corresponding byte
of the source data. The same function therefore both encrypts and decrypts,
because @(A xor B) xor B == A@.

When the crypt stream is infinite (as 'Crypt' requires), the output has the
same length as the input.
-}
crypt :: Crypt -> Key -> BL.ByteString -> BL.ByteString
crypt cry key input = BL.pack (zipWith xor keystream (BL.unpack input))
  where
    keystream = cry key

-- | Take the first four bytes of the input as a little-endian key, and decrypt
-- the rest of the input with it.
--
-- Input shorter than four bytes cannot supply a key; 'runGet' then fails with
-- an error rather than returning a result.
decrypt :: Crypt -> BL.ByteString -> BL.ByteString
decrypt cry = runGet $ do
  key  <- getWord32le
  body <- getRemainingLazyByteString
  return (crypt cry key body)

-- | Encrypt a string with a key, and prepend the key to the encrypted string.
--
-- The output is the key as four little-endian bytes, followed by the
-- encrypted input. This is the layout that 'decrypt' reads back.
encrypt :: Crypt -> Key -> BL.ByteString -> BL.ByteString
encrypt cry key input = runPut $ do
  -- Key header comes first; 'decrypt' reads it back with getWord32le.
  putWord32le key
  putLazyByteString (crypt cry key input)

-- | Decrypt an encrypted DTB file using the given crypt method.
--
-- The whole decrypted result is computed before the output file is opened.
-- Forcing it reads the input to end-of-file, which closes the lazily-read
-- input handle, so the input and output paths may be the same file
-- (decrypting in place). It also means a too-short input (fewer than four key
-- bytes) fails before the output file is created or truncated.
decryptFile :: Crypt -> FilePath -> FilePath -> IO ()
decryptFile cry fi fo = do
  encrypted <- BL.readFile fi
  let decrypted = decrypt cry encrypted
  BL.length decrypted `seq` BL.writeFile fo decrypted

-- | Encrypt an unencrypted DTB file using the given crypt method and key.
--
-- The whole encrypted result is computed before the output file is opened.
-- Forcing it reads the input to end-of-file, which closes the lazily-read
-- input handle, so the input and output paths may be the same file
-- (encrypting in place).
encryptFile :: Crypt -> Key -> FilePath -> FilePath -> IO ()
encryptFile cry key fi fo = do
  plain <- BL.readFile fi
  let encrypted = encrypt cry key plain
  BL.length encrypted `seq` BL.writeFile fo encrypted

-- | Decrypt an encrypted DTB file across two handles: read the contents of the
-- first handle and write the decrypted result to the second.
--
-- The input is read lazily with 'BL.hGetContents'.
decryptHandle :: Crypt -> Handle -> Handle -> IO ()
decryptHandle cry hi ho = do
  encrypted <- BL.hGetContents hi
  BL.hPutStr ho (decrypt cry encrypted)

-- | Encrypt an unencrypted DTB file across two handles: read the contents of
-- the first handle, encrypt it with the given key, and write the result
-- (key header included) to the second.
--
-- The input is read lazily with 'BL.hGetContents'.
encryptHandle :: Crypt -> Key -> Handle -> Handle -> IO ()
encryptHandle cry key hi ho = do
  plain <- BL.hGetContents hi
  BL.hPutStr ho (encrypt cry key plain)

-- New (Rock Band) encryption

{-
From xorloser's dtbcrypt:
unsigned int dtb_xor_x360(unsigned int data)
{
  int val1 = (data / 0x1F31D) * 0xB14;
  int val2 = (data - ((data / 0x1F31D) * 0x1F31D)) * 0x41A7;
  val2 = val2 - val1;
  if(val2 <= 0)
    val2 += 0x7FFFFFFF;
  return val2;
}
-}

-- | The key iteration function for new DTB encryption/decryption.
--
-- A transcription of @dtb_xor_x360@ from xorloser's dtbcrypt. The constants
-- are those of the Park-Miller \"minimal standard\" generator evaluated with
-- Schrage's method:
--
-- * multiplier 0x41A7 = 16807
-- * modulus 0x7FFFFFFF = 2^31 - 1
-- * 0x1F31D = 127773, the modulus divided by the multiplier
-- * 0xB14 = 2836, the remainder of that division
--
-- The C reference computes with signed @int@s and adds the modulus whenever
-- the difference is @<= 0@. Here the arithmetic is on 'Word32':
--
-- * A negative difference shows up as a wrapped value above 0x7FFFFFFF.
--   Adding 0x7FFFFFFF (mod 2^32) then gives the same result as the signed
--   addition.
-- * A difference of exactly zero (e.g. for inputs 0 and 0x7FFFFFFF) must
--   also be bumped to 0x7FFFFFFF to match the reference.
dtbXor360 :: Word32 -> Word32
dtbXor360 d
  | v == 0 || v > 0x7FFFFFFF = v + 0x7FFFFFFF
  | otherwise                = v
  where
    (q, r) = d `quotRem` 0x1F31D
    v      = r * 0x41A7 - q * 0xB14

-- | The lazy infinite list of crypt bytes for new-style encryption.
--
-- The key is only the seed and is not emitted itself. The first byte comes
-- from one application of 'dtbXor360', and each later byte from one more
-- application. Each byte is the low 8 bits of the corresponding state.
newCrypt :: Crypt
newCrypt key = [fromIntegral state | state <- drop 1 (iterate dtbXor360 key)]

-- Old (PS2 Guitar Hero) encryption
-- This algorithm uses a 249-entry table (indices 0 to 0xF8) to produce the XOR
-- bytes. The table is an STArray in the lazy ST monad, so the infinite byte
-- stream can be produced on demand.

-- | Mutable state for old-style crypt byte generation.
data CryptTable s = CryptTable
  { idx1  :: STRef s Word8
    -- ^ First read cursor into 'table'; also the slot that gets overwritten.
  , idx2  :: STRef s Word8
    -- ^ Second read cursor into 'table'.
  , table :: STArray s Word8 Word32
    -- ^ The table of 32-bit values the crypt bytes are derived from.
  }

-- | Build the initial old-style crypt state for a key.
--
-- The key seeds a 32-bit linear congruential generator
-- (@x -> x * 0x41C64E6D + 0x3039@, wrapping modulo 2^32). Each of the 0xF9
-- table entries consumes two consecutive generator outputs. The second output
-- is masked with 0x7FFF0000 and combined with the top 16 bits of the first.
-- The read cursors start at 0 and 0x67.
cryptTable :: Key -> ST s (CryptTable s)
cryptTable key = do
  tbl <- newArray (0, 0xF8) 0
  forM_ (zip [0 .. 0xF8] (tableEntries key)) $ \(i, entry) ->
    writeArray tbl i entry
  liftM3 CryptTable (newSTRef 0) (newSTRef 0x67) (return tbl)
  where
    -- One step of the generator.
    step :: Key -> Key
    step x = x * 0x41C64E6D + 0x3039

    -- Infinite list of table entries; each entry advances the seed twice.
    tableEntries :: Key -> [Key]
    tableEntries v1 =
      let v2  = step v1
          v1' = step v2
      in ((v1' .&. 0x7FFF0000) .|. (v2 `shiftR` 16)) : tableEntries v1'

-- | Produce the next old-style crypt byte, advancing the state.
--
-- Steps, in order:
--
-- 1. XOR the table entries under the two cursors.
-- 2. Store the result back at the first cursor.
-- 3. Advance both cursors, wrapping from 0xF8 back to 0.
--
-- The returned byte is the low 8 bits of the XOR result.
oldNext :: CryptTable s -> ST s Word8
oldNext (CryptTable i1ref i2ref tbl) = do
  i1 <- readSTRef i1ref
  i2 <- readSTRef i2ref
  next <- liftM2 xor (readArray tbl i1) (readArray tbl i2)
  writeArray tbl i1 next
  writeSTRef i1ref (advance i1)
  writeSTRef i2ref (advance i2)
  -- Converting the Word32 to Word8 keeps only its least significant byte.
  return (fromIntegral next)
  where
    advance :: Word8 -> Word8
    advance 0xF8 = 0
    advance i    = i + 1

-- | The lazy infinite list of crypt bytes for old-style encryption.
--
-- The table is built from the key, then 'oldNext' is run over and over. The
-- lazy ST monad lets the infinite result be consumed incrementally.
oldCrypt :: Crypt
oldCrypt key = runST stream
  where
    stream :: ST s [Word8]
    stream = do
      state <- cryptTable key
      sequence (repeat (oldNext state))