{-# OPTIONS_HADDOCK hide #-}

module Data.GeoIP2.SearchTree where

import           Data.Bits       (shift, testBit, (.&.), (.|.))
import qualified Data.ByteString as BS
import           Data.Int
import           Data.IP         (IP (..), fromIPv4, fromIPv6b)
import           Data.Word

-- | Convert a byte to its 8 bits, most significant bit first.
--
-- Only bits 7 down to 0 of the argument are inspected; any higher bits are
-- ignored.
byteToBits :: Int -> [Bool]
byteToBits b = map (testBit b) [7, 6 .. 0]

-- | Convert an IP address to a list of bits, most significant first.
--
-- The address is split into bytes with 'fromIPv4' (IPv4) or 'fromIPv6b'
-- (IPv6). Each byte is expanded to 8 bits with 'byteToBits', and the
-- results are concatenated in order.
ipToBits :: IP -> [Bool]
ipToBits ip = concatMap byteToBits octets
  where
    octets = case ip of
      IPv4 addr -> fromIPv4 addr
      IPv6 addr -> fromIPv6b addr

-- | Read one search-tree node and return its @(left, right)@ record pair.
--
-- Arguments:
--
-- * @mem@: the buffer containing the tree.
-- * @recordbits@: the size of one record in bits.
-- * @index@: the zero-based node number.
--
-- A node holds two records, so it occupies @recordbits \`div\` 4@ bytes
-- starting at byte @index * (recordbits \`div\` 4)@. Those bytes are read as
-- one big-endian unsigned integer. For 28-bit records the left record has a
-- split layout (see @left28@). For any other size, the left record is the
-- integer shifted right by @recordbits@. The right record is always the low
-- @recordbits@ bits.
--
-- 'BS.take' and 'BS.drop' are total, so a node that runs past the end of
-- @mem@ is read from fewer bytes. The result is then not meaningful.
readNode :: BS.ByteString -> Int -> Int64 -> (Int64, Int64)
readNode mem recordbits index =
  case recordbits of
    28 -> (fromIntegral left28, fromIntegral rightRecord)
    _  -> (fromIntegral (num `shift` negate recordbits), fromIntegral rightRecord)
  where
    -- Two records of @recordbits@ bits each: 2 * recordbits / 8 bytes.
    bytecount :: Int64
    bytecount = fromIntegral (recordbits `div` 4)

    -- The raw bytes of this node.
    bytes :: BS.ByteString
    bytes =
      BS.take (fromIntegral bytecount) $
        BS.drop (fromIntegral (index * bytecount)) mem

    -- The node's bytes as a big-endian unsigned integer.
    num :: Word64
    num = BS.foldl' (\acc byte -> acc * 256 + fromIntegral byte) 0 bytes

    -- Right record: the low @recordbits@ bits of the node.
    rightRecord :: Word64
    rightRecord = num .&. ((1 `shift` recordbits) - 1)

    -- A 28-bit node is 7 bytes. The left record's low 24 bits are bytes 0-2
    -- (@num >> 32@). Its top 4 bits are the high nibble of byte 3, which is
    -- moved from bit 28 down to bit 24. The low nibble of byte 3 belongs to
    -- the right record, so 'rightRecord' already includes it.
    left28 :: Word64
    left28 = (num `shift` (-32)) .|. ((num .&. 0xf0000000) `shift` (-4))

-- | Walk the search tree along the given address bits and return the
-- resolved offset into the data section.
--
-- The tuple holds the buffer containing the tree, the number of nodes in the
-- tree, and the record size in bits. The walk starts at node 0. Each bit
-- selects the next node index: the right record for @True@, the left record
-- for @False@ (as returned by 'readNode'). Before consuming a bit, the
-- current index is checked against the node count:
--
-- * @index == nodeCount@: there is no data for the address, and the walk
--   returns @Left "Information for address does not exist."@.
-- * @index > nodeCount@: the walk stops and returns
--   @Right (index - nodeCount - 16)@.
--
-- Running out of bits while still inside the tree yields
-- @Left "IP address too short"@.
getDataOffset :: (BS.ByteString, Int64, Int) -> [Bool] -> Either String Int64
getDataOffset (mem, nodeCount, recordSize) startbits = getnode startbits 0
  where
    getnode :: [Bool] -> Int64 -> Either String Int64
    getnode _ index
      | index == nodeCount = Left "Information for address does not exist."
      -- NOTE(review): the extra 16 presumably skips the 16-byte separator
      -- between the tree and the data section (MaxMind DB format); confirm.
      | index > nodeCount = Right (index - nodeCount - 16)
    getnode [] _ = Left "IP address too short"
    getnode (bit : rest) index = getnode rest nextOffset
      where
        (left, right) = readNode mem recordSize index
        nextOffset = if bit then right else left