module Network.Socket.BufferPool.Buffer (
    newBufferPool
  , withBufferPool
  , mallocBS
  , copy
  ) where

import qualified Data.ByteString as BS
import Data.ByteString.Internal (ByteString(..), memcpy)
import Data.ByteString.Unsafe (unsafeTake, unsafeDrop)
import Data.IORef (newIORef, readIORef, writeIORef)
import Foreign.ForeignPtr
import Foreign.Marshal.Alloc (mallocBytes, finalizerFree)
import Foreign.Ptr (castPtr, plusPtr)

import Network.Socket.BufferPool.Types

----------------------------------------------------------------

-- | Creating a buffer pool.
--   The first argument is the lower limit.
--   When the size of the buffer in the pool is lower than this limit,
--   the buffer is thrown away (and is eventually freed).
--   Then a new buffer is allocated.
--   The second argument is the size for the new allocation.
--
--   The pool starts out holding an empty 'ByteString', so no memory is
--   allocated until the pool is actually used.
newBufferPool :: Int -> Int -> IO BufferPool
newBufferPool l h = BufferPool l h <$> newIORef BS.empty

----------------------------------------------------------------

-- | Using a buffer pool.
--   The second argument is a function which returns
--   how many bytes are filled in the buffer.
--   The buffer in the buffer pool is automatically managed.
--
--   The function receives a pointer to the start of the free area and
--   the number of bytes available there. The filled prefix is returned
--   as a 'ByteString' and the unfilled remainder is kept in the pool
--   for the next call.
--
--   The returned count is passed to 'unsafeTake' / 'unsafeDrop' without
--   validation, so it must lie between 0 and the size that was offered.
withBufferPool :: BufferPool -> (Buffer -> BufSize -> IO Int) -> IO ByteString
withBufferPool (BufferPool l h ref) f = do
    buf0 <- readIORef ref
    -- Reuse the leftover buffer while it is still at least the lower
    -- limit; otherwise drop it and allocate a fresh one of size h.
    buf  <- if BS.length buf0 >= l then return buf0
                                   else mallocBS h
    consumed <- withForeignBuffer buf f
    -- NOTE(review): this read/write pair on the IORef is not atomic;
    -- presumably a pool is only used from one thread at a time -- confirm.
    writeIORef ref $ unsafeDrop consumed buf
    return $ unsafeTake consumed buf

-- | Running an action on the raw memory behind a 'ByteString'.
--   The action receives a pointer to the first byte of the payload
--   (the base pointer advanced by the byte string's offset) and the
--   payload length. Its result is returned unchanged.
--   The underlying 'ForeignPtr' is kept alive for the duration of the
--   action by 'withForeignPtr'.
withForeignBuffer :: ByteString -> (Buffer -> BufSize -> IO Int) -> IO Int
withForeignBuffer (PS ps s l) f =
    withForeignPtr ps $ \p -> f (castPtr p `plusPtr` s) l
{-# INLINE withForeignBuffer #-}

----------------------------------------------------------------

-- | Allocating a byte string.
--   The argument is the size in bytes. The memory comes from
--   'mallocBytes' and is released through 'finalizerFree' once the
--   resulting 'ForeignPtr' is garbage collected.
--   This function does not write to the memory, so the contents of the
--   returned byte string are whatever the allocator handed out.
mallocBS :: Int -> IO ByteString
mallocBS size = do
    ptr <- mallocBytes size
    fptr <- newForeignPtr finalizerFree ptr
    -- Offset 0, length equal to the whole allocation.
    return $ PS fptr 0 size
{-# INLINE mallocBS #-}

-- | Copying the bytestring to the buffer.
--   This function returns the point where the next copy should start,
--   that is, the destination pointer advanced by the length of the
--   byte string.
--
--   No bounds check is performed here: the caller has to make sure
--   that the destination has room for the whole byte string.
copy :: Buffer -> ByteString -> IO Buffer
copy ptr (PS fp o l) = withForeignPtr fp $ \p -> do
    -- The payload starts at offset o from the base pointer.
    -- fromIntegral converts the length to whatever size type this
    -- bytestring version's memcpy expects.
    memcpy ptr (p `plusPtr` o) (fromIntegral l)
    return $ ptr `plusPtr` l
{-# INLINE copy #-}