1- {-# LANGUAGE NamedFieldPuns #-}
2- {-# LANGUAGE ScopedTypeVariables #-}
1+ {-# LANGUAGE DisambiguateRecordFields #-}
2+ {-# LANGUAGE NamedFieldPuns #-}
3+ {-# LANGUAGE ScopedTypeVariables #-}
4+ {-# LANGUAGE TupleSections #-}
35
46-- | The module should be imported qualified.
57--
@@ -16,19 +18,19 @@ module Ouroboros.Network.TxSubmission.Mempool.Simple
1618import Prelude hiding (read , seq )
1719
1820import Control.Concurrent.Class.MonadSTM.Strict
19-
21+ import Data.Bitraversable
22+ import Data.Either
2023import Data.Foldable (toList )
2124import Data.Foldable qualified as Foldable
22- import Data.Function (on )
23- import Data.List (find , nubBy )
25+ import Data.List (find )
2426import Data.Maybe (isJust )
2527import Data.Sequence (Seq )
2628import Data.Sequence qualified as Seq
2729import Data.Set (Set )
2830import Data.Set qualified as Set
2931
32+ import Ouroboros.Network.Protocol.LocalTxSubmission.Type (SubmitResult (.. ))
3033import Ouroboros.Network.SizeInBytes
31- import Ouroboros.Network.TxSubmission.Inbound.V2.Types
3234import Ouroboros.Network.TxSubmission.Mempool.Reader
3335
3436
@@ -98,31 +100,67 @@ getReader getTxId getTxSize (Mempool mempool) =
98100 f idx tx = (getTxId tx, idx, getTxSize tx)
99101
100102
101- -- | A simple mempool writer.
103+ -- | A mempool writer which generalizes the tx submission mempool writer
104+ --
105+ data MempoolWriter txid tx failure idx m =
106+ MempoolWriter {
107+
108+ -- | Compute the transaction id from a transaction.
109+ --
110+ -- This is used in the protocol handler to verify a full transaction
111+ -- matches a previously given transaction id.
112+ --
113+ txId :: tx -> txid ,
114+
115+ -- | Supply a batch of transactions to the mempool. They are either
116+ -- accepted or rejected individually, but in the order supplied.
117+ --
118+ -- The 'txid's of all transactions that were added successfully are
119+ -- returned.
120+ mempoolAddTxs :: [tx ] -> m [SubmitResult failure ]
121+ }
122+
123+
124+ -- | A mempool writer with validation harness
125+ -- PRECONDITION: no duplicates given to mempoolAddTxs
102126--
103- getWriter :: forall tx txid m .
127+ getWriter :: forall tx txid ctx tx' failure m .
104128 ( MonadSTM m
105129 , Ord txid
106130 )
107131 => (tx -> txid )
108- -> (tx -> Bool )
109- -- ^ validate a tx
110- -> Mempool m tx
111- -> TxSubmissionMempoolWriter txid tx Int m
112- getWriter getTxId validateTx (Mempool mempool) =
113- TxSubmissionMempoolWriter {
114- txId = getTxId,
115-
116- mempoolAddTxs = \ txs -> do
117- atomically $ do
118- mempoolTxs <- readTVar mempool
119- let currentIds = Set. fromList (map getTxId (toList mempoolTxs))
120- validTxs = nubBy (on (==) getTxId)
121- $ filter
122- (\ tx -> validateTx tx
123- && getTxId tx `Set.notMember` currentIds)
124- txs
125- mempoolTxs' = Foldable. foldl' (Seq. |>) mempoolTxs validTxs
126- writeTVar mempool mempoolTxs'
127- return (map getTxId validTxs)
128- }
132+ -- ^ get txid of a tx
133+ -> m ctx
134+ -- ^ monadic validation context, acquired once prior to all work
135+ -> (ctx -> tx -> tx' )
136+ -- ^ pre-process every transanction with the context
137+ -> (tx' -> Bool -> Either failure tx )
138+ -- ^ validate a tx in an atomic block, any failing `tx` throws an exception.
139+ -> (failure -> STM m failure )
140+ -- ^ return `True` when a failure should throw an exception
141+ -> Mempool m txid tx
142+ -> MempoolWriter txid tx failure Int m
143+ getWriter getTxId acquireCtx preProcess validateTx failureFilterFn (Mempool mempool) =
144+ MempoolWriter {
145+ txId = getTxId,
146+
147+ mempoolAddTxs = \ txs -> do
148+ ctx <- acquireCtx
149+ let txs' = preProcess ctx <$> txs -- TODO could run in parallel
150+ atomically $ do
151+ MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool
152+ result <- sequence
153+ [bimapM (fmap SubmitFail . failureFilterFn) (pure . (txid,)) validated
154+ | (tx, tx') <- zip txs txs'
155+ , let txid = getTxId tx
156+ validated =
157+ validateTx tx' (txid `Set.member` mempoolSet)
158+ ]
159+ let (validIds, validTxs) = unzip . rights $ result
160+ mempoolTxs' = MempoolSeq {
161+ mempoolSet = Set. union mempoolSet (Set. fromList validIds),
162+ mempoolSeq = Foldable. foldl' (Seq. |>) mempoolSeq validTxs
163+ }
164+ writeTVar mempool mempoolTxs'
165+ return $ fromLeft SubmitSuccess <$> result
166+ }
0 commit comments