Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ type (

// pirated rpc "z_gettreestatelegacy" (legacy format for backward compatibility)
PiratedRpcReplyGettreestate struct {
Height int `json:"height"`
Hash string `json:"hash"`
Time uint32 `json:"time"`
Sprout struct {
Height int `json:"height"`
Hash string `json:"hash"`
Time uint32 `json:"time"`
Sprout struct {
SkipHash string `json:"skipHash,omitempty"`
Commitments struct {
FinalRoot string `json:"finalRoot"`
Expand All @@ -129,10 +129,10 @@ type (

// pirated rpc "z_gettreestate" (new format with bridge trees)
PiratedRpcReplyGetbridgetreestate struct {
Height int `json:"height"`
Hash string `json:"hash"`
Time uint32 `json:"time"`
Sprout struct {
Height int `json:"height"`
Hash string `json:"hash"`
Time uint32 `json:"time"`
Sprout struct {
Active bool `json:"active"`
SkipHash string `json:"skipHash,omitempty"`
Commitments struct {
Expand All @@ -158,6 +158,14 @@ type (
} `json:"orchard"`
}

// pirated rpc "z_getsubtreesbyindex"
PiratedRpcReplyGetsubtreesbyindex struct {
Index uint64 `json:"index"`
Root string `json:"root"`
CompletingBlockHash string `json:"completingBlockHash"`
CompletingBlockHeight uint64 `json:"completingBlockHeight"`
}

// pirated rpc "getrawtransaction txid 1" (1 means verbose), there are
PiratedRpcReplyGetrawtransaction struct {
Hex string
Expand Down
124 changes: 124 additions & 0 deletions frontend/getsubtreeroots_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) 2026 Pirate Chain developers
// Distributed under the MIT software license, see the accompanying
// file COPYING or https://www.opensource.org/licenses/mit-license.php .

package frontend

import (
"bytes"
"context"
"encoding/hex"
"encoding/json"
"testing"

"github.com/PirateNetwork/lightwalletd/common"
"github.com/PirateNetwork/lightwalletd/parser"
"github.com/PirateNetwork/lightwalletd/walletrpc"
"google.golang.org/grpc/metadata"
)

type subtreeRootsStream struct {
ctx context.Context
roots []*walletrpc.SubtreeRoot
}

func (s *subtreeRootsStream) SetHeader(metadata.MD) error { return nil }
func (s *subtreeRootsStream) SendHeader(metadata.MD) error { return nil }
func (s *subtreeRootsStream) SetTrailer(metadata.MD) {}
func (s *subtreeRootsStream) Context() context.Context { return s.ctx }
func (s *subtreeRootsStream) SendMsg(interface{}) error { return nil }
func (s *subtreeRootsStream) RecvMsg(interface{}) error { return nil }

func (s *subtreeRootsStream) Send(root *walletrpc.SubtreeRoot) error {
s.roots = append(s.roots, root)
return nil
}

func z_getsubtreesbyindexStub(method string, params []json.RawMessage) (json.RawMessage, error) {
if method != "z_getsubtreesbyindex" {
testT.Fatal("unexpected method in z_getsubtreesbyindexStub:", method)
}
if len(params) != 3 {
testT.Fatalf("unexpected params len in z_getsubtreesbyindexStub: %d", len(params))
}

var protocol string
if err := json.Unmarshal(params[0], &protocol); err != nil {
testT.Fatal("failed to parse protocol param:", err)
}
if protocol != "sapling" {
testT.Fatal("unexpected protocol param:", protocol)
}

var startIndex uint32
if err := json.Unmarshal(params[1], &startIndex); err != nil {
testT.Fatal("failed to parse startIndex param:", err)
}
if startIndex != 5 {
testT.Fatal("unexpected startIndex param:", startIndex)
}

var maxEntries uint32
if err := json.Unmarshal(params[2], &maxEntries); err != nil {
testT.Fatal("failed to parse maxEntries param:", err)
}
if maxEntries != 2 {
testT.Fatal("unexpected maxEntries param:", maxEntries)
}

mockResponse := `[
{
"index": 5,
"root": "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20",
"completingBlockHash": "00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff",
"completingBlockHeight": 12345
}
]`

return json.RawMessage(mockResponse), nil
}

func TestGetSubtreeRoots(t *testing.T) {
testT = t
common.RawRequest = z_getsubtreesbyindexStub

lwdInterface, err := NewLwdStreamer(nil, "main", false)
if err != nil {
t.Fatal("NewLwdStreamer failed:", err)
}
lwd := lwdInterface.(*lwdStreamer)

stream := &subtreeRootsStream{ctx: context.Background()}
err = lwd.GetSubtreeRoots(&walletrpc.GetSubtreeRootsArg{
StartIndex: 5,
ShieldedProtocol: walletrpc.ShieldedProtocol_sapling,
MaxEntries: 2,
}, stream)
if err != nil {
t.Fatal("GetSubtreeRoots failed:", err)
}

if len(stream.roots) != 1 {
t.Fatalf("unexpected subtree root count: %d", len(stream.roots))
}

expectedRoot, err := hex.DecodeString("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20")
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(stream.roots[0].RootHash, expectedRoot) {
t.Fatal("unexpected root hash bytes")
}

expectedBlockHash, err := hex.DecodeString("00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff")
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(stream.roots[0].CompletingBlockHash, parser.Reverse(expectedBlockHash)) {
t.Fatal("unexpected completing block hash bytes")
}

if stream.roots[0].CompletingBlockHeight != 12345 {
t.Fatal("unexpected completing block height:", stream.roots[0].CompletingBlockHeight)
}
}
94 changes: 82 additions & 12 deletions frontend/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,20 +129,19 @@ func (s *lwdStreamer) GetCurrentARRRPrice(ctx context.Context, in *walletrpc.Emp
return resp, nil
}


// Returns the last block in a group of predefined total size
func (s *lwdStreamer) GetLiteWalletBlockGroup(ctx context.Context, id *walletrpc.BlockID) (*walletrpc.BlockID, error) {
latestBlock := s.cache.GetLatestHeight()

if latestBlock == -1 {
return nil, errors.New("Cache is empty. Server is probably not yet ready")
return nil, errors.New("Cache is empty. Server is probably not yet ready")
}

if int(id.Height) < 1 {
return nil, errors.New("Invalid block, must use height greater than 0")
if int(id.Height) < 1 {
return nil, errors.New("Invalid block, must use height greater than 0")
}

blockId := s.cache.GetLiteWalletBlockGroup(int(id.Height))
blockId := s.cache.GetLiteWalletBlockGroup(int(id.Height))
return blockId, nil
}

Expand Down Expand Up @@ -368,7 +367,7 @@ func (s *lwdStreamer) GetTreeState(ctx context.Context, id *walletrpc.BlockID) (
}
params[0] = hashJSON
}

// Prefer the legacy z_gettreestatelegacy RPC
result, rpcErr := common.RawRequest("z_gettreestatelegacy", params)
if rpcErr == nil {
Expand All @@ -387,7 +386,7 @@ func (s *lwdStreamer) GetTreeState(ctx context.Context, id *walletrpc.BlockID) (
SaplingTree: saplingTree,
OrchardTree: "", // Legacy format does not support Orchard
}, nil
}
}
}

// Fallback to newer z_gettreestate RPC
Expand Down Expand Up @@ -451,7 +450,7 @@ func (s *lwdStreamer) GetBridgeTreeState(ctx context.Context, id *walletrpc.Bloc
// z_gettreestatelegacy doesn't exist - return error for consistency
return nil, rpcErr
}

// Node supports bridge trees - get tree state from z_gettreestate
result, rpcErr := common.RawRequest("z_gettreestate", params)
if rpcErr != nil {
Expand All @@ -463,19 +462,19 @@ func (s *lwdStreamer) GetBridgeTreeState(ctx context.Context, id *walletrpc.Bloc
if err != nil {
return nil, err
}

// Use Sapling finalState if available, otherwise use finalRoot
saplingTree := gettreestateReply.Sapling.Commitments.FinalState
if saplingTree == "" {
saplingTree = gettreestateReply.Sapling.Commitments.FinalRoot
}
// Use Orchard finalState if available, otherwise use finalRoot

// Use Orchard finalState if available, otherwise use finalRoot
orchardTree := gettreestateReply.Orchard.Commitments.FinalState
if orchardTree == "" {
orchardTree = gettreestateReply.Orchard.Commitments.FinalRoot
}

return &walletrpc.TreeState{
Network: s.chainName,
Height: uint64(gettreestateReply.Height),
Expand All @@ -486,6 +485,77 @@ func (s *lwdStreamer) GetBridgeTreeState(ctx context.Context, id *walletrpc.Bloc
}, nil
}

func (s *lwdStreamer) GetSubtreeRoots(
arg *walletrpc.GetSubtreeRootsArg,
resp walletrpc.CompactTxStreamer_GetSubtreeRootsServer,
) error {
if arg == nil {
return errors.New("request for subtree roots is missing")
}

var protocol string
switch arg.ShieldedProtocol {
case walletrpc.ShieldedProtocol_sapling:
protocol = "sapling"
case walletrpc.ShieldedProtocol_orchard:
protocol = "orchard"
default:
return errors.New("unsupported shielded protocol")
}

params := make([]json.RawMessage, 3)

protocolJSON, err := json.Marshal(protocol)
if err != nil {
return err
}
params[0] = protocolJSON

startIndexJSON, err := json.Marshal(arg.StartIndex)
if err != nil {
return err
}
params[1] = startIndexJSON

maxEntriesJSON, err := json.Marshal(arg.MaxEntries)
if err != nil {
return err
}
params[2] = maxEntriesJSON

result, rpcErr := common.RawRequest("z_getsubtreesbyindex", params)
if rpcErr != nil {
return rpcErr
}

var subtreeRoots []common.PiratedRpcReplyGetsubtreesbyindex
if err := json.Unmarshal(result, &subtreeRoots); err != nil {
return err
}

for _, subtree := range subtreeRoots {
rootHash, err := hex.DecodeString(subtree.Root)
if err != nil {
return err
}

completingBlockHash, err := hex.DecodeString(subtree.CompletingBlockHash)
if err != nil {
return err
}

if err := resp.Send(&walletrpc.SubtreeRoot{
RootHash: rootHash,
CompletingBlockHash: parser.Reverse(completingBlockHash),
CompletingBlockHeight: subtree.CompletingBlockHeight,
}); err != nil {
return err
}
}

return nil
}

// GetTransaction returns the raw transaction bytes that are returned
// by the pirated 'getrawtransaction' RPC.
func (s *lwdStreamer) GetTransaction(ctx context.Context, txf *walletrpc.TxFilter) (*walletrpc.RawTransaction, error) {
Expand Down
Loading