diff --git a/consensus/XDPoS/utils/pool.go b/consensus/XDPoS/utils/pool.go index 9c09a8ba53bd..869459f8dc4e 100644 --- a/consensus/XDPoS/utils/pool.go +++ b/consensus/XDPoS/utils/pool.go @@ -1,9 +1,11 @@ package utils import ( + "encoding/json" "sync" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/log" ) type PoolObj interface { @@ -21,8 +23,11 @@ func NewPool() *Pool { objList: make(map[string]map[common.Hash]PoolObj), } } + func (p *Pool) Get() map[string]map[common.Hash]PoolObj { - return p.objList + p.lock.RLock() + defer p.lock.RUnlock() + return p.getSnapshot() } func (p *Pool) Add(obj PoolObj) (int, map[common.Hash]PoolObj) { @@ -36,10 +41,13 @@ func (p *Pool) Add(obj PoolObj) (int, map[common.Hash]PoolObj) { } objListKeyed[obj.Hash()] = obj numOfItems := len(objListKeyed) - return numOfItems, objListKeyed + safeCopy := p.getSafePoolObjMap(objListKeyed) + return numOfItems, safeCopy } func (p *Pool) Size(obj PoolObj) int { + p.lock.Lock() + defer p.lock.Unlock() poolKey := obj.PoolKey() objListKeyed, ok := p.objList[poolKey] if !ok { @@ -84,8 +92,8 @@ func (p *Pool) Clear() { } func (p *Pool) GetObjsByKey(poolKey string) []PoolObj { - p.lock.Lock() - defer p.lock.Unlock() + p.lock.RLock() + defer p.lock.RUnlock() objListKeyed, ok := p.objList[poolKey] if !ok { @@ -94,8 +102,68 @@ func (p *Pool) GetObjsByKey(poolKey string) []PoolObj { objList := make([]PoolObj, len(objListKeyed)) cnt := 0 for _, obj := range objListKeyed { - objList[cnt] = obj - cnt += 1 + objList[cnt] = p.getSafePoolObj(obj) + cnt++ } return objList } + +// caller should hold lock +func (p *Pool) getSnapshot() map[string]map[common.Hash]PoolObj { + data, err := json.Marshal(p.objList) + if err != nil { + // This should never happen + log.Error("[getSafeCopy] Error while marshalling pool object list", "error", err) + return make(map[string]map[common.Hash]PoolObj) + } + + var dataCopy map[string]map[common.Hash]PoolObj + err = json.Unmarshal(data, &dataCopy) + if err != nil { + // This should never happen + log.Error("[getSafeCopy] Error while unmarshalling pool object list", "error", err) + return make(map[string]map[common.Hash]PoolObj) + } + + return dataCopy +} + +// caller should hold lock +func (p *Pool) getSafePoolObjMap(objMap map[common.Hash]PoolObj) map[common.Hash]PoolObj { + data, err := json.Marshal(objMap) + if err != nil { + // This should never happen + log.Error("[getSafeCopy] Error while marshalling pool object list", "error", err) + return make(map[common.Hash]PoolObj) + } + + var dataCopy map[common.Hash]PoolObj + err = json.Unmarshal(data, &dataCopy) + if err != nil { + // This should never happen + log.Error("[getSafeCopy] Error while unmarshalling pool object list", "error", err) + return make(map[common.Hash]PoolObj) + } + + return dataCopy +} + +// caller should hold lock +func (p *Pool) getSafePoolObj(obj PoolObj) PoolObj { + data, err := json.Marshal(obj) + if err != nil { + // This should never happen + log.Error("[getSafeCopy] Error while marshalling pool object list", "error", err) + return nil + } + + var dataCopy PoolObj + err = json.Unmarshal(data, &dataCopy) + if err != nil { + // This should never happen + log.Error("[getSafeCopy] Error while unmarshalling pool object list", "error", err) + return nil + } + + return dataCopy +}