make extract multithreaded. (#72)
* change tasks of byte ranges from slice to linked list

* default download threads to 4; fix tests [#68]

* return errors from download threads
bdon authored Sep 11, 2023
1 parent 35e9288 commit d935662
Showing 4 changed files with 103 additions and 51 deletions.
2 changes: 1 addition & 1 deletion go.mod
@@ -10,6 +10,7 @@ require (
github.com/schollz/progressbar/v3 v3.11.0
github.com/stretchr/testify v1.8.1
gocloud.dev v0.34.0
golang.org/x/sync v0.3.0
zombiezen.com/go/sqlite v0.10.1
)

@@ -71,7 +72,6 @@ require (
golang.org/x/crypto v0.11.0 // indirect
golang.org/x/net v0.13.0 // indirect
golang.org/x/oauth2 v0.10.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.10.0 // indirect
golang.org/x/term v0.10.0 // indirect
golang.org/x/text v0.11.0 // indirect
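golang.org/x/sync is promoted from an indirect to a direct dependency because extract.go now imports golang.org/x/sync/errgroup to fan out downloads and propagate the first worker error. A minimal sketch of that pattern, with illustrative names only, not code from this repository:

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	// errgroup runs a pool of goroutines; Wait() returns the first
	// non-nil error, which is how Extract now surfaces download failures.
	errs, _ := errgroup.WithContext(context.Background())
	for i := 0; i < 4; i++ {
		i := i // capture the loop variable (needed before Go 1.22)
		errs.Go(func() error {
			fmt.Println("worker", i, "finished")
			return nil // a real worker would return its download error here
		})
	}
	if err := errs.Wait(); err != nil {
		fmt.Println("a worker failed:", err)
	}
}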
3 changes: 2 additions & 1 deletion main.go
@@ -50,6 +50,7 @@ var cli struct {
Bucket string `help:"Remote bucket of input archive."`
Region string `help:"local GeoJSON Polygon or MultiPolygon file for area of interest." type:"existingfile"`
Maxzoom uint8 `help:"Maximum zoom level, inclusive."`
DownloadThreads int `default:4 help:"Number of download threads."`
DryRun bool `help:"Calculate tiles to extract, but don't download them."`
Overfetch float32 `default:0.1 help:"What ratio of extra data to download to minimize # requests; 0.2 is 20%"`
} `cmd:"" help:"Create an archive from a larger archive for a subset of zoom levels or geographic region."`
@@ -120,7 +121,7 @@ func main() {
logger.Printf("Serving %s %s on port %d with Access-Control-Allow-Origin: %s\n", cli.Serve.Bucket, cli.Serve.Path, cli.Serve.Port, cli.Serve.Cors)
logger.Fatal(http.ListenAndServe(":"+strconv.Itoa(cli.Serve.Port), nil))
case "extract <input> <output>":
err := pmtiles.Extract(logger, cli.Extract.Bucket, cli.Extract.Input, cli.Extract.Maxzoom, cli.Extract.Region, cli.Extract.Output, cli.Extract.Overfetch, cli.Extract.DryRun)
err := pmtiles.Extract(logger, cli.Extract.Bucket, cli.Extract.Input, cli.Extract.Maxzoom, cli.Extract.Region, cli.Extract.Output, cli.Extract.DownloadThreads, cli.Extract.Overfetch, cli.Extract.DryRun)
if err != nil {
logger.Fatalf("Failed to extract, %v", err)
}
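For callers using the package directly rather than the CLI, pmtiles.Extract gains a download-thread count between the output path and the overfetch ratio. A hedged usage sketch, assuming the module path github.com/protomaps/go-pmtiles and placeholder bucket and file names:

package main

import (
	"log"
	"os"

	"github.com/protomaps/go-pmtiles/pmtiles"
)

func main() {
	logger := log.New(os.Stdout, "extract: ", log.LstdFlags)

	// Signature after this commit: the download thread count sits between
	// the output path and the overfetch ratio. All paths below are placeholders.
	err := pmtiles.Extract(
		logger,
		"s3://example-bucket", // remote bucket of the input archive
		"planet.pmtiles",      // input archive name
		10,                    // maxzoom, inclusive
		"region.geojson",      // GeoJSON Polygon/MultiPolygon area of interest
		"extract.pmtiles",     // output archive
		4,                     // download threads (new parameter, default 4 in the CLI)
		0.1,                   // overfetch ratio
		false,                 // dry run
	)
	if err != nil {
		logger.Fatalf("Failed to extract, %v", err)
	}
}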
123 changes: 85 additions & 38 deletions pmtiles/extract.go
@@ -3,20 +3,23 @@ package pmtiles
import (
"bytes"
"context"
"container/list"
"fmt"
"github.com/RoaringBitmap/roaring/roaring64"
"github.com/dustin/go-humanize"
"github.com/paulmach/orb"
"github.com/paulmach/orb/geojson"
"github.com/schollz/progressbar/v3"
"gocloud.dev/blob"
"golang.org/x/sync/errgroup"
"io"
"io/ioutil"
"log"
"math"
"os"
"sort"
"strings"
"sync"
"time"
)
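
The new container/list import supplies the doubly-linked list that replaces the slice of byte-range tasks: workers can pop items from either end in O(1) under a lock, with no index bookkeeping. A minimal sketch of the three operations the new code relies on (values are illustrative):

package main

import (
	"container/list"
	"fmt"
)

func main() {
	queue := list.New()
	for _, n := range []int{3, 1, 2} {
		queue.PushBack(n) // enqueue at the tail
	}

	// Pop from the front, as most download workers do...
	front := queue.Remove(queue.Front()).(int)
	// ...or from the back, as the dedicated back-draining worker does.
	back := queue.Remove(queue.Back()).(int)

	fmt.Println(front, back, queue.Len()) // 3 2 1
}

Remove returns the element's value as an interface{}, hence the type assertion; the diff uses the same pattern to pull OverfetchRange values off the list.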

@@ -157,10 +160,10 @@ type OverfetchListItem struct {
// input ranges are merged in order of smallest byte distance to next range
// until the overfetch budget is consumed.
// The slice is sorted by Length
func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
func MergeRanges(ranges []SrcDstRange, overfetch float32) (*list.List, uint64) {
total_size := 0

list := make([]*OverfetchListItem, len(ranges))
shortest := make([]*OverfetchListItem, len(ranges))

// create the heap items
for i, rng := range ranges {
@@ -174,7 +177,7 @@ func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
}
}

list[i] = &OverfetchListItem{
shortest[i] = &OverfetchListItem{
Rng: rng,
BytesToNext: bytes_to_next,
CopyDiscards: []CopyDiscard{{uint64(rng.Length), 0}},
@@ -183,21 +186,18 @@ func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
}

// make the list doubly-linked
for i, item := range list {
for i, item := range shortest {
if i > 0 {
item.prev = list[i-1]
item.prev = shortest[i-1]
}
if i < len(list)-1 {
item.next = list[i+1]
if i < len(shortest)-1 {
item.next = shortest[i+1]
}
}

overfetch_budget := int(float32(total_size) * overfetch)

// create a 2nd slice, sorted by ascending distance to next range
shortest := make([]*OverfetchListItem, len(list))
copy(shortest, list)

// sort by ascending distance to next range
sort.Slice(shortest, func(i, j int) bool {
return shortest[i].BytesToNext < shortest[j].BytesToNext
})
@@ -221,21 +221,21 @@ func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
overfetch_budget -= int(item.BytesToNext)
}

// copy out the result structs
result := make([]OverfetchRange, len(shortest))

sort.Slice(shortest, func(i, j int) bool {
return shortest[i].Rng.DstOffset < shortest[j].Rng.DstOffset
return shortest[i].Rng.Length > shortest[j].Rng.Length
})

for i, x := range shortest {
result[i] = OverfetchRange{
total_bytes := uint64(0)
result := list.New()
for _, x := range shortest {
result.PushBack(OverfetchRange{
Rng: x.Rng,
CopyDiscards: x.CopyDiscards,
}
})
total_bytes += x.Rng.Length
}

return result
return result, total_bytes
}
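
To make the new return values concrete, here is a sketch of what the updated tests at the bottom of this diff exercise: two ranges, bytes 0-49 and 60-119, separated by a 10-byte gap; with a 10% overfetch budget (11 bytes of the 110-byte total) the gap is absorbed, so a single merged request is issued and the second return value reports the 120 bytes actually transferred.

// Sketch only; mirrors TestMergeRanges in pmtiles/extract_test.go.
ranges := []SrcDstRange{}
ranges = append(ranges, SrcDstRange{0, 0, 50})
ranges = append(ranges, SrcDstRange{60, 60, 60})

merged, totalTransferBytes := MergeRanges(ranges, 0.1)

fmt.Println(merged.Len())       // 1: the 10-byte gap fits the 11-byte overfetch budget
fmt.Println(totalTransferBytes) // 120: bytes fetched, including the gap
first := merged.Front().Value.(OverfetchRange)
fmt.Println(first.Rng)          // {0 0 120}
fmt.Println(first.CopyDiscards) // [{50 10} {60 0}]: keep 50, skip 10, keep 60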

// 1. Get the root directory (check that it is clustered)
@@ -252,7 +252,7 @@ func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
// 10. write the leaf directories (if any)
// 11. Get all tiles, and write directly to the output.

func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, region_file string, output string, overfetch float32, dry_run bool) error {
func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, region_file string, output string, download_threads int, overfetch float32, dry_run bool) error {
// 1. fetch the header

if bucketURL == "" {
@@ -342,10 +342,15 @@ func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, r
leaf_ranges = append(leaf_ranges, SrcDstRange{header.LeafDirectoryOffset + leaf.Offset, 0, uint64(leaf.Length)})
}

overfetch_leaves := MergeRanges(leaf_ranges, overfetch)
fmt.Printf("fetching %d dirs, %d chunks, %d requests\n", len(leaves), len(leaf_ranges), len(overfetch_leaves))
overfetch_leaves, _ := MergeRanges(leaf_ranges, overfetch)
num_overfetch_leaves := overfetch_leaves.Len()
fmt.Printf("fetching %d dirs, %d chunks, %d requests\n", len(leaves), len(leaf_ranges), overfetch_leaves.Len())

for _, or := range overfetch_leaves {
for {
if overfetch_leaves.Len() == 0 {
break
}
or := overfetch_leaves.Remove(overfetch_leaves.Front()).(OverfetchRange)

slab_r, err := bucket.NewRangeReader(ctx, file, int64(or.Rng.SrcOffset), int64(or.Rng.Length), nil)
if err != nil {
@@ -385,9 +390,12 @@ func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, r
// we now need to re-encode this entry list using cumulative offsets
reencoded, tile_parts, tiledata_length, addressed_tiles, tile_contents := ReencodeEntries(tile_entries)

overfetch_ranges := MergeRanges(tile_parts, overfetch)
fmt.Printf("fetching %d tiles, %d chunks, %d requests\n", len(reencoded), len(tile_parts), len(overfetch_ranges))
overfetch_ranges, total_bytes := MergeRanges(tile_parts, overfetch)

num_overfetch_ranges := overfetch_ranges.Len()
fmt.Printf("fetching %d tiles, %d chunks, %d requests\n", len(reencoded), len(tile_parts), overfetch_ranges.Len())

// TODO: takes up too much RAM
// construct the directories
new_root_bytes, new_leaves_bytes, _ := optimize_directories(reencoded, 16384-HEADERV3_LEN_BYTES)

@@ -414,15 +422,18 @@ func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, r

header_bytes := serialize_header(header)

total_bytes := uint64(0)
for _, x := range overfetch_ranges {
total_bytes += x.Rng.Length
total_actual_bytes := uint64(0)
for _, x := range tile_parts {
total_actual_bytes += x.Length
}

if !dry_run {

outfile, err := os.Create(output)
defer outfile.Close()

outfile.Truncate(127 + int64(len(new_root_bytes)) + int64(header.MetadataLength) + int64(len(new_leaves_bytes)) + int64(total_actual_bytes))

_, err = outfile.Write(header_bytes)
if err != nil {
return err
@@ -458,38 +469,74 @@ func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, r
"fetching chunks",
)

for _, or := range overfetch_ranges {
var mu sync.Mutex

downloadPart := func(or OverfetchRange) error {
tile_r, err := bucket.NewRangeReader(ctx, file, int64(source_tile_data_offset+or.Rng.SrcOffset), int64(or.Rng.Length), nil)
if err != nil {
return err
}
offset_writer := io.NewOffsetWriter(outfile, int64(header.TileDataOffset)+int64(or.Rng.DstOffset))

for _, cd := range or.CopyDiscards {
_, err := io.CopyN(io.MultiWriter(outfile, bar), tile_r, int64(cd.Wanted))

_, err := io.CopyN(io.MultiWriter(offset_writer, bar), tile_r, int64(cd.Wanted))
if err != nil {
return err
}

_, err = io.CopyN(io.MultiWriter(io.Discard, bar), tile_r, int64(cd.Discard))
_, err = io.CopyN(bar, tile_r, int64(cd.Discard))
if err != nil {
return err
}
}
tile_r.Close()
return nil
}
}

total_actual_bytes := uint64(0)
for _, x := range tile_parts {
total_actual_bytes += x.Length
errs, _ := errgroup.WithContext(ctx)

for i := 0; i < download_threads; i++ {
work_back := (i == 0 && download_threads > 1)
errs.Go(func() error {
done := false
var or OverfetchRange
for {
mu.Lock()
if overfetch_ranges.Len() == 0 {
done = true
} else {
if work_back {
or = overfetch_ranges.Remove(overfetch_ranges.Back()).(OverfetchRange)
} else {
or = overfetch_ranges.Remove(overfetch_ranges.Front()).(OverfetchRange)
}
}
mu.Unlock()
if done {
return nil
}
err := downloadPart(or)
if err != nil {
return err
}
}

return nil
})
}

err = errs.Wait()
if err != nil {
return err
}
}

fmt.Printf("Completed in %v seconds with 1 download thread.\n", time.Since(start))
fmt.Printf("Completed in %v with %v download threads.\n", time.Since(start), download_threads)
total_requests := 2 // header + root
total_requests += len(overfetch_leaves) // leaves
total_requests += num_overfetch_leaves // leaves
total_requests += 1 // metadata
total_requests += len(overfetch_ranges)
total_requests += num_overfetch_ranges
fmt.Printf("Extract required %d total requests.\n", total_requests)
fmt.Printf("Extract transferred %s (overfetch %v) for an archive size of %s\n", humanize.Bytes(total_bytes), overfetch, humanize.Bytes(total_actual_bytes))
fmt.Println("Verify your extract is usable at https://protomaps.github.io/PMTiles/")
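The heart of the change is the download loop: instead of one sequential pass, a fixed number of workers (default 4) pull OverfetchRange items off the shared list under a mutex, one worker consumes from the back of the length-sorted list while the rest take from the front, failures propagate through errgroup, and each worker writes into its own region of the pre-truncated output file via io.NewOffsetWriter so no file offset is shared. A self-contained sketch of that worker pattern under simplified assumptions: plain in-memory payloads stand in for remote range reads, and all names are illustrative.

package main

import (
	"container/list"
	"fmt"
	"io"
	"os"
	"sync"

	"golang.org/x/sync/errgroup"
)

// job stands in for an OverfetchRange: a destination offset plus the bytes to place there.
type job struct {
	dstOffset int64
	payload   []byte
}

func main() {
	// Queue of work, analogous to the list returned by MergeRanges.
	queue := list.New()
	for i := 0; i < 8; i++ {
		queue.PushBack(job{dstOffset: int64(i) * 4, payload: []byte(fmt.Sprintf("%03d\n", i))})
	}

	out, err := os.Create("demo.bin")
	if err != nil {
		panic(err)
	}
	defer out.Close()
	// Pre-size the file so concurrent writes at fixed offsets are well-defined,
	// mirroring the outfile.Truncate call in the diff.
	if err := out.Truncate(8 * 4); err != nil {
		panic(err)
	}

	var mu sync.Mutex
	var errs errgroup.Group
	const downloadThreads = 4

	for i := 0; i < downloadThreads; i++ {
		workBack := i == 0 && downloadThreads > 1 // one worker drains from the back
		errs.Go(func() error {
			for {
				mu.Lock()
				if queue.Len() == 0 {
					mu.Unlock()
					return nil
				}
				var j job
				if workBack {
					j = queue.Remove(queue.Back()).(job)
				} else {
					j = queue.Remove(queue.Front()).(job)
				}
				mu.Unlock()

				// Each job touches only its own byte range of the shared file.
				w := io.NewOffsetWriter(out, j.dstOffset)
				if _, err := w.Write(j.payload); err != nil {
					return err // surfaces through errs.Wait()
				}
			}
		})
	}

	if err := errs.Wait(); err != nil {
		panic(err)
	}
	fmt.Println("wrote demo.bin with", downloadThreads, "workers")
}

io.NewOffsetWriter (Go 1.20+) wraps an io.WriterAt at a fixed base offset, which is why the diff truncates the output file up front: each goroutine then writes only within its own destination range.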
26 changes: 15 additions & 11 deletions pmtiles/extract_test.go
@@ -141,13 +141,15 @@ func TestMergeRanges(t *testing.T) {
ranges = append(ranges, SrcDstRange{0, 0, 50})
ranges = append(ranges, SrcDstRange{60, 60, 60})

result := MergeRanges(ranges, 0.1)

assert.Equal(t, 1, len(result))
assert.Equal(t, SrcDstRange{0, 0, 120}, result[0].Rng)
assert.Equal(t, 2, len(result[0].CopyDiscards))
assert.Equal(t, CopyDiscard{50, 10}, result[0].CopyDiscards[0])
assert.Equal(t, CopyDiscard{60, 0}, result[0].CopyDiscards[1])
result, total_transfer_bytes := MergeRanges(ranges, 0.1)

assert.Equal(t, 1, result.Len())
assert.Equal(t, uint64(120), total_transfer_bytes)
front := result.Front().Value.(OverfetchRange)
assert.Equal(t, SrcDstRange{0, 0, 120}, front.Rng)
assert.Equal(t, 2, len(front.CopyDiscards))
assert.Equal(t, CopyDiscard{50, 10}, front.CopyDiscards[0])
assert.Equal(t, CopyDiscard{60, 0}, front.CopyDiscards[1])
}

func TestMergeRangesMultiple(t *testing.T) {
@@ -156,9 +158,11 @@ func TestMergeRangesMultiple(t *testing.T) {
ranges = append(ranges, SrcDstRange{60, 60, 10})
ranges = append(ranges, SrcDstRange{80, 80, 10})

result := MergeRanges(ranges, 0.3)
assert.Equal(t, 1, len(result))
assert.Equal(t, SrcDstRange{0, 0, 90}, result[0].Rng)
assert.Equal(t, 3, len(result[0].CopyDiscards))
result, total_transfer_bytes := MergeRanges(ranges, 0.3)
front := result.Front().Value.(OverfetchRange)
assert.Equal(t, uint64(90), total_transfer_bytes)
assert.Equal(t, 1, result.Len())
assert.Equal(t, SrcDstRange{0, 0, 90}, front.Rng)
assert.Equal(t, 3, len(front.CopyDiscards))
fmt.Println(result)
}
