diff --git a/cmd/thor/flags.go b/cmd/thor/flags.go index 3a5fac58b..7b47380b0 100644 --- a/cmd/thor/flags.go +++ b/cmd/thor/flags.go @@ -79,11 +79,15 @@ var ( Value: "any", Usage: "port mapping mechanism (any|none|upnp|pmp|extip:)", } - bootNodeFlag = cli.StringFlag{ Name: "bootnode", Usage: "comma separated list of bootnode IDs", } + allowedPeersFlag = cli.StringFlag{ + Name: "allowed-peers", + Hidden: true, + Usage: "comma separated list of node IDs that can be connected to", + } importMasterKeyFlag = cli.BoolFlag{ Name: "import", Usage: "import master key from keystore", diff --git a/cmd/thor/main.go b/cmd/thor/main.go index b08d3fda2..5e3c054db 100644 --- a/cmd/thor/main.go +++ b/cmd/thor/main.go @@ -83,6 +83,7 @@ func main() { p2pPortFlag, natFlag, bootNodeFlag, + allowedPeersFlag, skipLogsFlag, pprofFlag, verifyLogsFlag, diff --git a/cmd/thor/utils.go b/cmd/thor/utils.go index 9b81d0d65..0250a0ea9 100644 --- a/cmd/thor/utils.go +++ b/cmd/thor/utils.go @@ -424,14 +424,19 @@ type p2pComm struct { } func newP2PComm(ctx *cli.Context, repo *chain.Repository, txPool *txpool.TxPool, instanceDir string) (*p2pComm, error) { + // known peers will be loaded/stored from/in this file + peersCachePath := filepath.Join(instanceDir, "peers.cache") + configDir, err := makeConfigDir(ctx) if err != nil { return nil, err } + key, err := loadOrGeneratePrivateKey(filepath.Join(configDir, "p2p.key")) if err != nil { return nil, errors.Wrap(err, "load or generate P2P key") } + nat, err := nat.Parse(ctx.String(natFlag.Name)) if err != nil { cli.ShowAppHelp(ctx) @@ -448,30 +453,46 @@ func newP2PComm(ctx *cli.Context, repo *chain.Repository, txPool *txpool.TxPool, NAT: nat, } - peersCachePath := filepath.Join(instanceDir, "peers.cache") - - if data, err := os.ReadFile(peersCachePath); err != nil { - if !os.IsNotExist(err) { + // allowed peers flag will only allow p2psrv to connect to the designated peers + flagAllowedPeers := strings.TrimSpace(ctx.String(allowedPeersFlag.Name)) + if flagAllowedPeers != "" { + opts.NoDiscovery = true // disable discovery + opts.KnownNodes, err = parseNodeList(flagAllowedPeers) + if err != nil { + return nil, errors.Wrap(err, "parse allowed-peers flag") + } + } else { + var knownNodes p2psrv.Nodes + if data, err := os.ReadFile(peersCachePath); err != nil { + if !os.IsNotExist(err) { + log.Warn("failed to load peers cache", "err", err) + } + } else if err := rlp.DecodeBytes(data, &knownNodes); err != nil { log.Warn("failed to load peers cache", "err", err) } - } else if err := rlp.DecodeBytes(data, &opts.KnownNodes); err != nil { - log.Warn("failed to load peers cache", "err", err) - } - flagBootstrapNodes := parseBootNode(ctx) - if flagBootstrapNodes != nil { - opts.BootstrapNodes = flagBootstrapNodes - opts.RemoteBootstrap = "" + // boot nodes flag will overwrite the default bootstrap nodes and also disable remote bootstrap + flagBootstrapNodes := strings.TrimSpace(ctx.String(bootNodeFlag.Name)) + if flagBootstrapNodes != "" { + opts.RemoteBootstrap = "" // disable remote bootstrap + opts.BootstrapNodes, err = parseNodeList(flagBootstrapNodes) + if err != nil { + return nil, errors.Wrap(err, "parse bootnodes flag") + } - m := make(map[discover.NodeID]bool) - for _, node := range opts.KnownNodes { - m[node.ID] = true - } - for _, bootnode := range flagBootstrapNodes { - if !m[bootnode.ID] { - opts.KnownNodes = append(opts.KnownNodes, bootnode) + m := make(map[discover.NodeID]bool) + for _, node := range knownNodes { + m[node.ID] = true + } + //appending user supplied boot nodes to known nodes since they potentially could be a p2p server + for _, bootnode := range opts.BootstrapNodes { + if !m[bootnode.ID] { + knownNodes = append(opts.KnownNodes, bootnode) + } } } + + opts.KnownNodes = knownNodes } return &p2pComm{ @@ -621,16 +642,18 @@ func printSoloStartupMessage( fmt.Print(info) } -func parseBootNode(ctx *cli.Context) []*discover.Node { - s := strings.TrimSpace(ctx.String(bootNodeFlag.Name)) - if s == "" { - return nil - } - inputs := strings.Split(s, ",") +func parseNodeList(list string) ([]*discover.Node, error) { + inputs := strings.Split(list, ",") var nodes []*discover.Node for _, i := range inputs { - node := discover.MustParseNode(i) + node, err := discover.ParseNode(i) + if err != nil { + return nil, err + } nodes = append(nodes, node) } - return nodes + if len(nodes) == 0 { + return nil, errors.New("empty node list") + } + return nodes, nil } diff --git a/p2psrv/server.go b/p2psrv/server.go index 609516b41..f4a455e71 100644 --- a/p2psrv/server.go +++ b/p2psrv/server.go @@ -54,8 +54,8 @@ func New(opts *Options) *Server { Name: opts.Name, PrivateKey: opts.PrivateKey, MaxPeers: opts.MaxPeers, - NoDiscovery: true, - DiscoveryV5: false, // disable discovery inside p2p.Server instance + NoDiscovery: true, // disable discovery inside p2p.Server instance(we use our own) + DiscoveryV5: false, // disable discovery inside p2p.Server instance(we use our own) ListenAddr: opts.ListenAddr, NetRestrict: opts.NetRestrict, NAT: opts.NAT, @@ -200,6 +200,7 @@ func (s *Server) listenDiscV5() (err error) { for _, node := range s.opts.BootstrapNodes { s.bootstrapNodes = append(s.bootstrapNodes, discv5.NewNode(discv5.NodeID(node.ID), node.IP, node.UDP, node.TCP)) } + // known nodes are also acting as bootstrap servers for _, node := range s.opts.KnownNodes { s.bootstrapNodes = append(s.bootstrapNodes, discv5.NewNode(discv5.NodeID(node.ID), node.IP, node.UDP, node.TCP)) } @@ -304,6 +305,7 @@ func (s *Server) dialLoop() { s.dialingNodes.Remove(node.ID) log.Debug("failed to dial node", "err", err) } + s.discoveredNodes.Remove(node.ID) }() dialCount++ diff --git a/p2psrv/server_test.go b/p2psrv/server_test.go index 1aed6e5ad..5cdc45d12 100644 --- a/p2psrv/server_test.go +++ b/p2psrv/server_test.go @@ -3,4 +3,75 @@ // Distributed under the GNU Lesser General Public License v3.0 software license, see the accompanying // file LICENSE or -package p2psrv_test +package p2psrv + +import ( + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/stretchr/testify/assert" +) + +func TestNewServer(t *testing.T) { + privateKey, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Unable to generate private key: %v", err) + } + + node := discover.MustParseNode("enode://1234cf28ab5f0255a3923ac094d0168ce884a9fa5f3998b1844986b4a2b1eac52fcccd8f2916be9b8b0f7798147ee5592ec3c83518925fac50f812577515d6ad@10.3.58.6:30303?discport=30301") + opts := &Options{ + Name: "testNode", + PrivateKey: privateKey, + MaxPeers: 10, + ListenAddr: ":30303", + NetRestrict: nil, + NAT: nil, + NoDial: false, + KnownNodes: Nodes{node}, + } + + server := New(opts) + + assert.Equal(t, "testNode", server.opts.Name) + assert.Equal(t, privateKey, server.opts.PrivateKey) + assert.Equal(t, 10, server.opts.MaxPeers) + assert.Equal(t, ":30303", server.opts.ListenAddr) + assert.Equal(t, server.discoveredNodes.Len(), 1) + assert.Equal(t, server.knownNodes.Len(), 1) + assert.True(t, server.discoveredNodes.Contains(node.ID)) + assert.True(t, server.knownNodes.Contains(node.ID)) + assert.False(t, server.opts.NoDial) +} + +func TestNewServerConnectOnly(t *testing.T) { + privateKey, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Unable to generate private key: %v", err) + } + + knownNode := discover.MustParseNode("enode://1234cf28ab5f0255a3923ac094d0168ce884a9fa5f3998b1844986b4a2b1eac52fcccd8f2916be9b8b0f7798147ee5592ec3c83518925fac50f812577515d6ad@10.3.58.6:30303?discport=30301") + opts := &Options{ + Name: "testNode", + PrivateKey: privateKey, + MaxPeers: 10, + ListenAddr: ":30303", + NetRestrict: nil, + NAT: nil, + NoDial: false, + KnownNodes: Nodes{knownNode}, + } + + server := New(opts) + + assert.Equal(t, "testNode", server.opts.Name) + assert.Equal(t, privateKey, server.opts.PrivateKey) + assert.Equal(t, 10, server.opts.MaxPeers) + assert.Equal(t, ":30303", server.opts.ListenAddr) + assert.False(t, server.opts.NoDial) + + assert.Equal(t, server.discoveredNodes.Len(), 1) + assert.Equal(t, server.knownNodes.Len(), 1) + assert.True(t, server.discoveredNodes.Contains(knownNode.ID)) + assert.True(t, server.knownNodes.Contains(knownNode.ID)) +}