Skip to content

Commit e90e707

Browse files
feat(SPV-748): add response to getHeaders
1 parent a2d198d commit e90e707

File tree

9 files changed

+382
-65
lines changed

9 files changed

+382
-65
lines changed

database/repository/header_repository.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,30 @@ func (r *HeaderRepository) GetChainBetweenTwoHashes(low string, high string) ([]
169169
}
170170
return nil, err
171171
}
172+
173+
// GetHeadersStartHeight returns height of the highest header from the list of hashes.
174+
func (r *HeaderRepository) GetHeadersStartHeight(hashtable []string) (int, error) {
175+
sh, err := r.db.GetHeadersStartHeight(hashtable)
176+
if err != nil {
177+
return 0, err
178+
}
179+
return sh, nil
180+
}
181+
182+
// GetHeadersByHeightRange returns headers from db in specified height range.
183+
func (r *HeaderRepository) GetHeadersByHeightRange(from int, to int) ([]*domains.BlockHeader, error) {
184+
bh, err := r.db.GetHeadersByHeightRange(from, to)
185+
if err != nil {
186+
return nil, err
187+
}
188+
return dto.ConvertToBlockHeader(bh), nil
189+
}
190+
191+
// GetHeadersStopHeight returns height of hashstop header from db.
192+
func (r *HeaderRepository) GetHeadersStopHeight(hashStop string) (int, error) {
193+
hs, err := r.db.GetHeadersStopHeight(hashStop)
194+
if err != nil {
195+
return 0, err
196+
}
197+
return hs, nil
198+
}

database/sql/headers.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import (
1616
const (
1717
HeadersTableName = "headers"
1818

19+
longestChainState = "LONGEST_CHAIN"
20+
1921
sqlInsertHeader = `
2022
INSERT INTO headers(hash, height, version, merkleroot, nonce, bits, header_state, chainwork, previous_block, timestamp , cumulated_work)
2123
VALUES(:hash, :height, :version, :merkleroot, :nonce, :bits, :header_state, :chainwork, :previous_block, :timestamp, :cumulated_work)
@@ -34,6 +36,12 @@ const (
3436
WHERE hash = ?
3537
`
3638

39+
sqlHeaderHeightFromHashAndState = `
40+
SELECT height
41+
FROM headers
42+
WHERE hash = ? AND header_state = ?
43+
`
44+
3745
sqlHeaderByHeight = `
3846
SELECT hash, height, version, merkleroot, nonce, bits, chainwork, previous_block, timestamp, header_state, cumulated_work
3947
FROM headers
@@ -160,6 +168,20 @@ const (
160168
sqlTipOfChainHeight = `SELECT MAX(height) FROM headers WHERE header_state = 'LONGEST_CHAIN'`
161169

162170
sqlVerifyHash = `SELECT hash FROM headers WHERE merkleroot = $1 AND height = $2 AND header_state = 'LONGEST_CHAIN'`
171+
172+
sqlGetHeadersHeight = `
173+
SELECT COALESCE(MAX(height), 0) AS startHeight
174+
FROM headers
175+
WHERE header_state = 'LONGEST_CHAIN'
176+
AND hash IN (?)
177+
`
178+
179+
sqlHeaderByHeightRangeLongestChain = `
180+
SELECT
181+
hash, height, version, merkleroot, nonce, bits, chainwork, previous_block, timestamp, header_state, cumulated_work
182+
FROM headers
183+
WHERE height BETWEEN ? AND ? AND header_state = 'LONGEST_CHAIN';
184+
`
163185
)
164186

165187
// HeadersDb represents a database connection and map of related sql queries.
@@ -396,6 +418,45 @@ func (h *HeadersDb) GetMerkleRootsConfirmations(
396418
return confirmations, nil
397419
}
398420

421+
// GetHashStartHeight returns hash and height from db with given locators.
422+
func (h *HeadersDb) GetHeadersStartHeight(hashTable []string) (int, error) {
423+
query, args, err := sqlx.In(sqlGetHeadersHeight, hashTable)
424+
if err != nil {
425+
h.log.Error().Err(err).Msg("Error while constructing query")
426+
return 0, err
427+
}
428+
429+
var heightStart int
430+
if err := h.db.Get(&heightStart, h.db.Rebind(query), args...); err != nil {
431+
h.log.Error().Err(err).Msg("Failed to get headers by locators")
432+
return 0, err
433+
}
434+
435+
return heightStart, nil
436+
}
437+
438+
// GetHeadersStopHeight will return header from db with given hash.
439+
func (h *HeadersDb) GetHeadersStopHeight(hashStop string) (int, error) {
440+
var dbHashStopHeight int
441+
if err := h.db.Get(&dbHashStopHeight, h.db.Rebind(sqlHeaderHeightFromHashAndState), hashStop, longestChainState); err != nil {
442+
if errors.Is(err, sql.ErrNoRows) {
443+
return 0, nil
444+
}
445+
return 0, errors.Wrapf(err, "failed to get stophash %s", hashStop)
446+
}
447+
448+
return dbHashStopHeight, nil
449+
}
450+
451+
// GetHeadersByHeightRange returns headers from db in specified height range.
452+
func (h *HeadersDb) GetHeadersByHeightRange(from int, to int) ([]*dto.DbBlockHeader, error) {
453+
var listOfHeaders []*dto.DbBlockHeader
454+
if err := h.db.Select(&listOfHeaders, h.db.Rebind(sqlHeaderByHeightRangeLongestChain), from, to); err != nil {
455+
return nil, errors.Wrapf(err, "failed to get headers using given range from: %d to: %d", from, to)
456+
}
457+
return listOfHeaders, nil
458+
}
459+
399460
func (h *HeadersDb) getChainTipHeight() (int32, error) {
400461
var tipHeight int32
401462
err := h.db.Get(&tipHeight, sqlTipOfChainHeight)

internal/tests/fixtures/blockheader_util.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@ const (
1313
DefaultChainWork = 4295032833
1414
)
1515

16-
// HashOf returns chainhash.Hash representation of string, ignoring errors.
16+
// HashOf returns chainhash.Hash representation of string, panic when error occurs.
1717
func HashOf(s string) *chainhash.Hash {
18-
h, _ := chainhash.NewHashFromStr(s)
18+
h, err := chainhash.NewHashFromStr(s)
19+
if err != nil {
20+
panic("Invalid hash string")
21+
}
1922
return h
2023
}
2124

internal/tests/testrepository/header_testrepository.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,39 @@ func (r *HeaderTestRepository) GetMerkleRootsConfirmations(
238238
return mrcfs, nil
239239
}
240240

241+
func (r *HeaderTestRepository) GetHeadersStartHeight(hashtable []string) (int, error) {
242+
for i := len(*r.db) - 1; i >= 0; i-- {
243+
header := (*r.db)[i]
244+
for j := len(hashtable) - 1; j >= 0; j-- {
245+
if header.Hash.String() == hashtable[j] && header.State == domains.LongestChain {
246+
return int(header.Height), nil
247+
}
248+
}
249+
}
250+
return 0, nil
251+
}
252+
253+
func (r *HeaderTestRepository) GetHeadersByHeightRange(from int, to int) ([]*domains.BlockHeader, error) {
254+
filteredHeaders := make([]*domains.BlockHeader, 0)
255+
for _, header := range *r.db {
256+
if header.Height >= int32(from) && header.Height <= int32(to) {
257+
headerCopy := header
258+
filteredHeaders = append(filteredHeaders, &headerCopy)
259+
}
260+
}
261+
return filteredHeaders, nil
262+
}
263+
264+
func (r *HeaderTestRepository) GetHeadersStopHeight(hashStop string) (int, error) {
265+
for i := len(*r.db) - 1; i >= 0; i-- {
266+
header := (*r.db)[i]
267+
if header.Hash.String() == hashStop {
268+
return int(header.Height), nil
269+
}
270+
}
271+
return 0, errors.New("could not find stop height")
272+
}
273+
241274
// FillWithLongestChain fills the test header repository
242275
// with 4 additional blocks to create a longest chain.
243276
func (r *HeaderTestRepository) FillWithLongestChain() {

internal/transports/p2p/peer/peer.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ func (p *Peer) readMsgHandler() {
259259
p.handleHeadersMsg(msg)
260260
case *wire.MsgInv:
261261
p.handleInvMsg(msg)
262+
case *wire.MsgGetHeaders:
263+
p.handleGetHeadersMsg(msg)
262264
default:
263265
p.log.Info().Msgf("received msg of type: %T", msg)
264266
}
@@ -555,6 +557,24 @@ func (p *Peer) handleHeadersMsg(msg *wire.MsgHeaders) {
555557
}
556558
}
557559

560+
func (p *Peer) handleGetHeadersMsg(msg *wire.MsgGetHeaders) {
561+
p.log.Info().Msgf("received getheaders msg from peer %s", p)
562+
if !p.syncedCheckpoints {
563+
p.log.Info().Msgf("we are still syncing, ignoring getHeaders msg from peer %s", p)
564+
return
565+
}
566+
567+
bh, err := p.headersService.LocateHeadersGetHeaders(msg.BlockLocatorHashes, &msg.HashStop)
568+
if err != nil {
569+
p.log.Error().Msgf("error locating headers for getheaders msg from peer %s, reason: %v", p, err)
570+
return
571+
}
572+
573+
msgHeaders := wire.NewMsgHeaders()
574+
msgHeaders.Headers = bh
575+
p.queueMessage(msgHeaders)
576+
}
577+
558578
func (p *Peer) switchToSendHeadersMode() {
559579
if !p.sendHeadersMode && p.protocolVersion >= wire.SendHeadersVersion {
560580
p.log.Info().Msgf("switching to send headers mode - requesting peer %s to send us headers directly instead of inv msg", p)

repository/repository.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ type Headers interface {
2525
GetAllTips() ([]*domains.BlockHeader, error)
2626
GetAncestorOnHeight(hash string, height int32) (*domains.BlockHeader, error)
2727
GetChainBetweenTwoHashes(low string, high string) ([]*domains.BlockHeader, error)
28+
GetHeadersStartHeight(hashtable []string) (int, error)
29+
GetHeadersByHeightRange(from int, to int) ([]*domains.BlockHeader, error)
30+
GetHeadersStopHeight(hashStop string) (int, error)
2831
}
2932

3033
// Tokens is a interface which represents methods performed on tokens table in defined storage.

service/header_service.go

Lines changed: 67 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -291,86 +291,90 @@ func (hs *HeaderService) GetMerkleRootsConfirmations(
291291
return hs.repo.Headers.GetMerkleRootsConfirmations(request, hs.merkleCfg.MaxBlockHeightExcess)
292292
}
293293

294-
// LocateHeaders fetches headers for a number of blocks after the most recent known block
295-
// in the best chain, based on the provided block locator and stop hash, and defaults to the
296-
// genesis block if the locator is unknown.
297-
func (hs *HeaderService) LocateHeaders(locator domains.BlockLocator, hashStop *chainhash.Hash) []wire.BlockHeader {
298-
headers := hs.locateHeaders(locator, hashStop, wire.MaxBlockHeadersPerMsg)
299-
return headers
294+
// LocateHeadersGetHeaders returns headers with given hashes.
295+
func (hs *HeaderService) LocateHeadersGetHeaders(locators []*chainhash.Hash, hashstop *chainhash.Hash) ([]*wire.BlockHeader, error) {
296+
headers, err := hs.locateHeadersGetHeaders(locators, hashstop)
297+
if err != nil {
298+
return nil, err
299+
}
300+
return headers, nil
300301
}
301302

302-
func (hs *HeaderService) locateHeaders(locator domains.BlockLocator, hashStop *chainhash.Hash, maxHeaders uint32) []wire.BlockHeader {
303-
// Find the node after the first known block in the locator and the
304-
// total number of nodes after it needed while respecting the stop hash
305-
// and max entries.
306-
node, total := hs.locateInventory(locator, hashStop, maxHeaders)
307-
if total == 0 {
308-
return nil
303+
func (hs *HeaderService) locateHeadersGetHeaders(locators []*chainhash.Hash, hashstop *chainhash.Hash) ([]*wire.BlockHeader, error) {
304+
305+
if len(locators) == 0 {
306+
return nil, errors.New("no locators provided")
309307
}
310308

311-
// Populate and return the found headers.
312-
headers := make([]wire.BlockHeader, 0, total)
313-
for i := uint32(0); i < total; i++ {
314-
header := wire.BlockHeader{
315-
Version: node.Version,
316-
PrevBlock: node.PreviousBlock,
317-
MerkleRoot: node.MerkleRoot,
318-
Timestamp: node.Timestamp,
319-
Bits: node.Bits,
320-
Nonce: node.Nonce,
321-
}
322-
headers = append(headers, header)
323-
node = hs.nodeByHeight(node.Height + 1)
309+
hashes := make([]string, len(locators))
310+
for i, v := range locators {
311+
hashes[i] = v.String()
324312
}
325-
return headers
326-
}
327313

328-
func (hs *HeaderService) locateInventory(locator domains.BlockLocator, hashStop *chainhash.Hash, maxEntries uint32) (*domains.BlockHeader, uint32) {
329-
// There are no block locators so a specific block is being requested
330-
// as identified by the stop hash.
331-
stopNode, _ := hs.GetHeaderByHash(hashStop.String())
332-
if len(locator) == 0 {
333-
if stopNode == nil {
334-
// No blocks with the stop hash were found so there is
335-
// nothing to do.
336-
return nil, 0
314+
startHeight, err := hs.repo.Headers.GetHeadersStartHeight(hashes)
315+
if err != nil {
316+
return nil, fmt.Errorf("error getting headers of locators: %v", err)
317+
}
318+
var stopHeight int
319+
if hashstop.IsEqual(&chainhash.Hash{}) {
320+
stopHeight = startHeight + wire.MaxCFHeadersPerMsg
321+
} else {
322+
stopHeight, err = hs.repo.Headers.GetHeadersStopHeight(hashstop.String())
323+
if err != nil {
324+
return nil, fmt.Errorf("error getting hashstop height: %v", err)
337325
}
338-
return stopNode, 1
339326
}
340327

341-
// Find the most recent locator block hash in the main chain. In the
342-
// case none of the hashes in the locator are in the main chain, fall
343-
// back to the genesis block.
344-
startNode, _ := hs.repo.Headers.GetHeaderByHeight(0)
345-
for _, hash := range locator {
346-
node, _ := hs.GetHeaderByHash(hash.String())
347-
if node != nil && hs.Contains(node) {
348-
startNode = node
349-
break
350-
}
328+
if stopHeight == 0 {
329+
stopHeight = startHeight + wire.MaxCFHeadersPerMsg
351330
}
352331

353-
// Start at the block after the most recently known block. When there
354-
// is no next block it means the most recently known block is the tip of
355-
// the best chain, so there is nothing more to do.
356-
next := hs.Next(startNode)
357-
if next == nil {
358-
return nil, 0
332+
if stopHeight <= startHeight {
333+
return nil, errors.New("hashStop is lower than first valid height")
359334
}
360-
startNode = next
361335

362-
// Calculate how many entries are needed.
363-
total := uint32((hs.GetTipHeight() - startNode.Height) + 1)
364-
if stopNode != nil && hs.Contains(stopNode) &&
365-
stopNode.Height >= startNode.Height {
336+
// Check if peer requested number of headers is higher than the maximum number of headers per message
337+
if wire.MaxCFHeadersPerMsg < stopHeight-startHeight {
338+
stopHeight = startHeight + wire.MaxCFHeadersPerMsg
339+
}
366340

367-
total = uint32((stopNode.Height - startNode.Height) + 1)
341+
dbHeaders, err := hs.repo.Headers.GetHeadersByHeightRange(startHeight+1, stopHeight)
342+
if err != nil {
343+
return nil, fmt.Errorf("error getting headers between heights: %v", err)
344+
}
345+
346+
headers := make([]*wire.BlockHeader, 0, len(dbHeaders))
347+
for _, dbHeader := range dbHeaders {
348+
header := &wire.BlockHeader{
349+
Version: dbHeader.Version,
350+
PrevBlock: dbHeader.PreviousBlock,
351+
MerkleRoot: dbHeader.MerkleRoot,
352+
Timestamp: dbHeader.Timestamp,
353+
Bits: dbHeader.Bits,
354+
Nonce: dbHeader.Nonce,
355+
}
356+
headers = append(headers, header)
368357
}
369-
if total > maxEntries {
370-
total = maxEntries
358+
359+
return headers, nil
360+
}
361+
362+
// LocateHeaders fetches headers for a number of blocks after the most recent known block
363+
// in the best chain, based on the provided block locator and stop hash, and defaults to the
364+
// genesis block if the locator is unknown.
365+
func (hs *HeaderService) LocateHeaders(locator domains.BlockLocator, hashStop *chainhash.Hash) []wire.BlockHeader {
366+
headers, err := hs.locateHeadersGetHeaders(locator, hashStop)
367+
if err != nil {
368+
hs.log.Error().Msg(err.Error())
369+
return nil
370+
}
371+
372+
result := make([]wire.BlockHeader, 0, len(headers))
373+
for _, header := range headers {
374+
result = append(result, *header)
371375
}
372376

373-
return startNode, total
377+
return result
374378
}
375379

376380
// Contains checks if given header is stored in db.

0 commit comments

Comments
 (0)