Skip to content

Commit ae47e22

Browse files
committed
feat(x/meg): Support capturing components (#269)
* Use Matchable interface * Add Bytes to Matchable interface * feat(x/meg): Support capturing bytes * Export CaptureWithF Can be used by more specific capturers (e.g capture net.AddrIP) * Support Any match, RawValue, and multiple Concatenations * Add CaptureAddrPort
1 parent 09e5347 commit ae47e22

File tree

6 files changed

+282
-58
lines changed

6 files changed

+282
-58
lines changed

meg_capturers.go

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package multiaddr
2+
3+
import (
4+
"encoding/binary"
5+
"fmt"
6+
"net/netip"
7+
8+
"github.com/multiformats/go-multiaddr/x/meg"
9+
)
10+
11+
func CaptureAddrPort(network *string, ipPort *netip.AddrPort) (capturePattern meg.Pattern) {
12+
var ipOnly netip.Addr
13+
capturePort := func(s meg.Matchable) error {
14+
switch s.Code() {
15+
case P_UDP:
16+
*network = "udp"
17+
case P_TCP:
18+
*network = "tcp"
19+
default:
20+
return fmt.Errorf("invalid network: %s", s.Value())
21+
}
22+
23+
port := binary.BigEndian.Uint16(s.RawValue())
24+
*ipPort = netip.AddrPortFrom(ipOnly, port)
25+
return nil
26+
}
27+
28+
pattern := meg.Cat(
29+
meg.Or(
30+
meg.CaptureWithF(P_IP4, func(s meg.Matchable) error {
31+
var ok bool
32+
ipOnly, ok = netip.AddrFromSlice(s.RawValue())
33+
if !ok {
34+
return fmt.Errorf("invalid ip4 address: %s", s.Value())
35+
}
36+
return nil
37+
}),
38+
meg.CaptureWithF(P_IP6, func(s meg.Matchable) error {
39+
var ok bool
40+
ipOnly, ok = netip.AddrFromSlice(s.RawValue())
41+
if !ok {
42+
return fmt.Errorf("invalid ip6 address: %s", s.Value())
43+
}
44+
return nil
45+
}),
46+
),
47+
meg.Or(
48+
meg.CaptureWithF(P_UDP, capturePort),
49+
meg.CaptureWithF(P_TCP, capturePort),
50+
),
51+
)
52+
53+
return pattern
54+
}

meg_test.go

+29-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package multiaddr
22

33
import (
4+
"net/netip"
45
"testing"
56

67
"github.com/multiformats/go-multiaddr/x/meg"
@@ -16,10 +17,10 @@ func TestMatchAndCaptureMultiaddr(t *testing.T) {
1617
meg.Val(P_IP4),
1718
meg.Val(P_IP6),
1819
),
19-
meg.CaptureVal(P_UDP, &udpPort),
20+
meg.CaptureStringVal(P_UDP, &udpPort),
2021
meg.Val(P_QUIC_V1),
2122
meg.Val(P_WEBTRANSPORT),
22-
meg.CaptureZeroOrMore(P_CERTHASH, &certhashes),
23+
meg.CaptureZeroOrMoreStringVals(P_CERTHASH, &certhashes),
2324
)
2425
if !found {
2526
t.Fatal("failed to match")
@@ -43,3 +44,29 @@ func TestMatchAndCaptureMultiaddr(t *testing.T) {
4344
}
4445
}
4546
}
47+
48+
func TestCaptureAddrPort(t *testing.T) {
49+
m := StringCast("/ip4/1.2.3.4/udp/8231/quic-v1/webtransport")
50+
var addrPort netip.AddrPort
51+
var network string
52+
53+
found, err := m.Match(
54+
CaptureAddrPort(&network, &addrPort),
55+
meg.ZeroOrMore(meg.Any),
56+
)
57+
if err != nil {
58+
t.Fatal("error", err)
59+
}
60+
if !found {
61+
t.Fatal("failed to match")
62+
}
63+
if !addrPort.IsValid() {
64+
t.Fatal("failed to capture addrPort")
65+
}
66+
if network != "udp" {
67+
t.Fatal("unexpected network", network)
68+
}
69+
if addrPort.String() != "1.2.3.4:8231" {
70+
t.Fatal("unexpected ipPort", addrPort)
71+
}
72+
}

x/meg/bench_test.go

+58-17
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func preallocateCapture() *preallocatedCapture {
2222
),
2323
meg.Val(multiaddr.P_UDP),
2424
meg.Val(multiaddr.P_WEBRTC_DIRECT),
25-
meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes),
25+
meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &p.certHashes),
2626
)
2727
return p
2828
}
@@ -87,19 +87,19 @@ func isWebTransportMultiaddrPrealloc() *preallocatedCapture {
8787
var sni string
8888
p.matcher = meg.PatternToMatcher(
8989
meg.Or(
90-
meg.CaptureVal(multiaddr.P_IP4, &ip4Addr),
91-
meg.CaptureVal(multiaddr.P_IP6, &ip6Addr),
92-
meg.CaptureVal(multiaddr.P_DNS4, &dnsName),
93-
meg.CaptureVal(multiaddr.P_DNS6, &dnsName),
94-
meg.CaptureVal(multiaddr.P_DNS, &dnsName),
90+
meg.CaptureStringVal(multiaddr.P_IP4, &ip4Addr),
91+
meg.CaptureStringVal(multiaddr.P_IP6, &ip6Addr),
92+
meg.CaptureStringVal(multiaddr.P_DNS4, &dnsName),
93+
meg.CaptureStringVal(multiaddr.P_DNS6, &dnsName),
94+
meg.CaptureStringVal(multiaddr.P_DNS, &dnsName),
9595
),
96-
meg.CaptureVal(multiaddr.P_UDP, &udpPort),
96+
meg.CaptureStringVal(multiaddr.P_UDP, &udpPort),
9797
meg.Val(multiaddr.P_QUIC_V1),
9898
meg.Optional(
99-
meg.CaptureVal(multiaddr.P_SNI, &sni),
99+
meg.CaptureStringVal(multiaddr.P_SNI, &sni),
100100
),
101101
meg.Val(multiaddr.P_WEBTRANSPORT),
102-
meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes),
102+
meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &p.certHashes),
103103
)
104104
wtPrealloc = p
105105
return p
@@ -120,26 +120,55 @@ func IsWebTransportMultiaddr(m multiaddr.Multiaddr) (bool, int) {
120120
var certHashesStr []string
121121
matched, _ := m.Match(
122122
meg.Or(
123-
meg.CaptureVal(multiaddr.P_IP4, &ip4Addr),
124-
meg.CaptureVal(multiaddr.P_IP6, &ip6Addr),
125-
meg.CaptureVal(multiaddr.P_DNS4, &dnsName),
126-
meg.CaptureVal(multiaddr.P_DNS6, &dnsName),
127-
meg.CaptureVal(multiaddr.P_DNS, &dnsName),
123+
meg.CaptureStringVal(multiaddr.P_IP4, &ip4Addr),
124+
meg.CaptureStringVal(multiaddr.P_IP6, &ip6Addr),
125+
meg.CaptureStringVal(multiaddr.P_DNS4, &dnsName),
126+
meg.CaptureStringVal(multiaddr.P_DNS6, &dnsName),
127+
meg.CaptureStringVal(multiaddr.P_DNS, &dnsName),
128128
),
129-
meg.CaptureVal(multiaddr.P_UDP, &udpPort),
129+
meg.CaptureStringVal(multiaddr.P_UDP, &udpPort),
130130
meg.Val(multiaddr.P_QUIC_V1),
131131
meg.Optional(
132-
meg.CaptureVal(multiaddr.P_SNI, &sni),
132+
meg.CaptureStringVal(multiaddr.P_SNI, &sni),
133133
),
134134
meg.Val(multiaddr.P_WEBTRANSPORT),
135-
meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &certHashesStr),
135+
meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &certHashesStr),
136136
)
137137
if !matched {
138138
return false, 0
139139
}
140140
return true, len(certHashesStr)
141141
}
142142

143+
func IsWebTransportMultiaddrCaptureBytes(m multiaddr.Multiaddr) (bool, int) {
144+
var dnsName []byte
145+
var ip4Addr []byte
146+
var ip6Addr []byte
147+
var udpPort []byte
148+
var sni []byte
149+
var certHashes [][]byte
150+
matched, _ := m.Match(
151+
meg.Or(
152+
meg.CaptureBytes(multiaddr.P_IP4, &ip4Addr),
153+
meg.CaptureBytes(multiaddr.P_IP6, &ip6Addr),
154+
meg.CaptureBytes(multiaddr.P_DNS4, &dnsName),
155+
meg.CaptureBytes(multiaddr.P_DNS6, &dnsName),
156+
meg.CaptureBytes(multiaddr.P_DNS, &dnsName),
157+
),
158+
meg.CaptureBytes(multiaddr.P_UDP, &udpPort),
159+
meg.Val(multiaddr.P_QUIC_V1),
160+
meg.Optional(
161+
meg.CaptureBytes(multiaddr.P_SNI, &sni),
162+
),
163+
meg.Val(multiaddr.P_WEBTRANSPORT),
164+
meg.CaptureZeroOrMoreBytes(multiaddr.P_CERTHASH, &certHashes),
165+
)
166+
if !matched {
167+
return false, 0
168+
}
169+
return true, len(certHashes)
170+
}
171+
143172
func IsWebTransportMultiaddrNoCapture(m multiaddr.Multiaddr) (bool, int) {
144173
matched, _ := m.Match(
145174
meg.Or(
@@ -355,6 +384,18 @@ func BenchmarkIsWebTransportMultiaddrNoCapture(b *testing.B) {
355384
}
356385
}
357386

387+
func BenchmarkIsWebTransportMultiaddrCaptureBytes(b *testing.B) {
388+
addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport")
389+
390+
b.ResetTimer()
391+
for i := 0; i < b.N; i++ {
392+
isWT, count := IsWebTransportMultiaddrCaptureBytes(addr)
393+
if !isWT || count != 0 {
394+
b.Fatal("unexpected result")
395+
}
396+
}
397+
}
398+
358399
func BenchmarkIsWebTransportMultiaddr(b *testing.B) {
359400
addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport")
360401

x/meg/meg.go

+26-14
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,30 @@ import (
1212
type stateKind = int
1313

1414
const (
15-
done stateKind = (iota * -1) - 1
16-
// split anything else that is negative
15+
matchAny stateKind = (iota * -1) - 1
16+
// done MUST be the last stateKind in this list. We use it to determine if a
17+
// state is a split index.
18+
done
19+
// Anything that is less than done is a split index
1720
)
1821

1922
// MatchState is the Thompson NFA for a regular expression.
2023
type MatchState struct {
21-
capture captureFunc
24+
capture CaptureFunc
2225
// next is is the index of the next state. in the MatchState array.
2326
next int
2427
// If codeOrKind is negative, it is a kind.
25-
// If it is negative, but not a `done`, then it is the index to the next split.
28+
// If it is negative, and less than `done`, then it is the index to the next split.
2629
// This is done to keep the `MatchState` struct small and cache friendly.
2730
codeOrKind int
2831
}
2932

30-
type captureFunc func(string) error
33+
type CaptureFunc func(Matchable) error
3134

3235
// capture is a linked list of capture funcs with values.
3336
type capture struct {
34-
f captureFunc
35-
v string
37+
f CaptureFunc
38+
v Matchable
3639
prev *capture
3740
}
3841

@@ -53,7 +56,14 @@ func (s MatchState) String() string {
5356

5457
type Matchable interface {
5558
Code() int
56-
Value() string // Used when capturing the value
59+
// Value() returns the string representation of the matchable.
60+
Value() string
61+
// RawValue() returns the byte representation of the Value
62+
RawValue() []byte
63+
// Bytes() returns the underlying bytes of the matchable. For multiaddr
64+
// Components, this includes the protocol code and possibly the varint
65+
// encoded size.
66+
Bytes() []byte
5767
}
5868

5969
// Match returns whether the given Components match the Pattern defined in MatchState.
@@ -89,12 +99,12 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) {
8999
}
90100
for i, stateIndex := range currentStates.states {
91101
s := states[stateIndex]
92-
if s.codeOrKind >= 0 && s.codeOrKind == c.Code() {
102+
if s.codeOrKind == matchAny || (s.codeOrKind >= 0 && s.codeOrKind == c.Code()) {
93103
cm := currentStates.captures[i]
94104
if s.capture != nil {
95105
next := &capture{
96106
f: s.capture,
97-
v: c.Value(),
107+
v: c,
98108
}
99109
if cm == nil {
100110
cm = next
@@ -122,8 +132,8 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) {
122132
// Flip the order of the captures because we see captures from right
123133
// to left, but users expect them left to right.
124134
type captureWithVal struct {
125-
f captureFunc
126-
v string
135+
f CaptureFunc
136+
v Matchable
127137
}
128138
reversedCaptures := make([]captureWithVal, 0, 16)
129139
for c != nil {
@@ -190,10 +200,12 @@ func appendState(arr statesAndCaptures, states []MatchState, stateIndex int, c *
190200
return arr
191201
}
192202

203+
const splitIdxOffset = (-1 * (done - 1))
204+
193205
func storeSplitIdx(codeOrKind int) int {
194-
return (codeOrKind + 2) * -1
206+
return (codeOrKind + splitIdxOffset) * -1
195207
}
196208

197209
func restoreSplitIdx(splitIdx int) int {
198-
return (splitIdx * -1) - 2
210+
return (splitIdx * -1) - splitIdxOffset
199211
}

x/meg/meg_test.go

+28-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ func (c codeAndValue) Value() string {
2222
return c.val
2323
}
2424

25+
// Bytes implements Matchable.
26+
func (c codeAndValue) Bytes() []byte {
27+
return []byte(c.val)
28+
}
29+
30+
// RawValue implements Matchable.
31+
func (c codeAndValue) RawValue() []byte {
32+
return []byte(c.val)
33+
}
34+
2535
var _ Matchable = codeAndValue{}
2636

2737
func TestSimple(t *testing.T) {
@@ -33,6 +43,22 @@ func TestSimple(t *testing.T) {
3343
}
3444
testCases :=
3545
[]testCase{
46+
{
47+
pattern: PatternToMatcher(Val(Any), Val(1)),
48+
shouldMatch: [][]int{
49+
{0, 1},
50+
{1, 1},
51+
{2, 1},
52+
{3, 1},
53+
{4, 1},
54+
},
55+
shouldNotMatch: [][]int{
56+
{0},
57+
{0, 0},
58+
{0, 1, 0},
59+
},
60+
skipQuickCheck: true,
61+
},
3662
{
3763
pattern: PatternToMatcher(Val(0), Val(1)),
3864
shouldMatch: [][]int{{0, 1}},
@@ -119,7 +145,7 @@ func TestCapture(t *testing.T) {
119145
{
120146
setup: func() (Matcher, func()) {
121147
var code0str string
122-
return PatternToMatcher(CaptureVal(0, &code0str), Val(1)), func() {
148+
return PatternToMatcher(CaptureStringVal(0, &code0str), Val(1)), func() {
123149
if code0str != "hello" {
124150
panic("unexpected value")
125151
}
@@ -130,7 +156,7 @@ func TestCapture(t *testing.T) {
130156
{
131157
setup: func() (Matcher, func()) {
132158
var code0strs []string
133-
return PatternToMatcher(CaptureOneOrMore(0, &code0strs), Val(1)), func() {
159+
return PatternToMatcher(CaptureOneOrMoreStringVals(0, &code0strs), Val(1)), func() {
134160
if code0strs[0] != "hello" {
135161
panic("unexpected value")
136162
}

0 commit comments

Comments
 (0)