diff --git a/go.mod b/go.mod
index ad8042f..be229d2 100644
--- a/go.mod
+++ b/go.mod
@@ -10,6 +10,7 @@ require (
     github.com/schollz/progressbar/v3 v3.11.0
     github.com/stretchr/testify v1.8.1
     gocloud.dev v0.34.0
+    golang.org/x/sync v0.3.0
     zombiezen.com/go/sqlite v0.10.1
 )
 
@@ -71,7 +72,6 @@ require (
     golang.org/x/crypto v0.11.0 // indirect
     golang.org/x/net v0.13.0 // indirect
     golang.org/x/oauth2 v0.10.0 // indirect
-    golang.org/x/sync v0.3.0 // indirect
     golang.org/x/sys v0.10.0 // indirect
     golang.org/x/term v0.10.0 // indirect
     golang.org/x/text v0.11.0 // indirect
diff --git a/main.go b/main.go
index c6f4a22..9f0a143 100644
--- a/main.go
+++ b/main.go
@@ -50,6 +50,7 @@ var cli struct {
         Bucket string `help:"Remote bucket of input archive."`
         Region string `help:"local GeoJSON Polygon or MultiPolygon file for area of interest." type:"existingfile"`
         Maxzoom uint8 `help:"Maximum zoom level, inclusive."`
+        DownloadThreads int `default:4 help:"Number of download threads."`
         DryRun bool `help:"Calculate tiles to extract, but don't download them."`
         Overfetch float32 `default:0.1 help:"What ratio of extra data to download to minimize # requests; 0.2 is 20%"`
     } `cmd:"" help:"Create an archive from a larger archive for a subset of zoom levels or geographic region."`
@@ -120,7 +121,7 @@ func main() {
         logger.Printf("Serving %s %s on port %d with Access-Control-Allow-Origin: %s\n", cli.Serve.Bucket, cli.Serve.Path, cli.Serve.Port, cli.Serve.Cors)
         logger.Fatal(http.ListenAndServe(":"+strconv.Itoa(cli.Serve.Port), nil))
     case "extract ":
-        err := pmtiles.Extract(logger, cli.Extract.Bucket, cli.Extract.Input, cli.Extract.Maxzoom, cli.Extract.Region, cli.Extract.Output, cli.Extract.Overfetch, cli.Extract.DryRun)
+        err := pmtiles.Extract(logger, cli.Extract.Bucket, cli.Extract.Input, cli.Extract.Maxzoom, cli.Extract.Region, cli.Extract.Output, cli.Extract.DownloadThreads, cli.Extract.Overfetch, cli.Extract.DryRun)
         if err != nil {
             logger.Fatalf("Failed to extract, %v", err)
         }
diff --git a/pmtiles/extract.go b/pmtiles/extract.go
index 745c66e..6e036ee 100644
--- a/pmtiles/extract.go
+++ b/pmtiles/extract.go
@@ -3,6 +3,7 @@ package pmtiles
 import (
     "bytes"
     "context"
+    "container/list"
     "fmt"
     "github.com/RoaringBitmap/roaring/roaring64"
     "github.com/dustin/go-humanize"
@@ -10,6 +11,7 @@ import (
     "github.com/paulmach/orb/geojson"
     "github.com/schollz/progressbar/v3"
     "gocloud.dev/blob"
+    "golang.org/x/sync/errgroup"
     "io"
     "io/ioutil"
     "log"
@@ -17,6 +19,7 @@ import (
     "os"
     "sort"
     "strings"
+    "sync"
     "time"
 )
@@ -157,10 +160,10 @@ type OverfetchListItem struct {
 // input ranges are merged in order of smallest byte distance to next range
 // until the overfetch budget is consumed.
 // The slice is sorted by Length
-func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
+func MergeRanges(ranges []SrcDstRange, overfetch float32) (*list.List, uint64) {
     total_size := 0
 
-    list := make([]*OverfetchListItem, len(ranges))
+    shortest := make([]*OverfetchListItem, len(ranges))
 
     // create the heap items
     for i, rng := range ranges {
@@ -174,7 +177,7 @@ func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
             }
         }
 
-        list[i] = &OverfetchListItem{
+        shortest[i] = &OverfetchListItem{
             Rng: rng,
             BytesToNext: bytes_to_next,
             CopyDiscards: []CopyDiscard{{uint64(rng.Length), 0}},
@@ -183,21 +186,18 @@ func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
     }
 
     // make the list doubly-linked
-    for i, item := range list {
+    for i, item := range shortest {
         if i > 0 {
-            item.prev = list[i-1]
+            item.prev = shortest[i-1]
         }
-        if i < len(list)-1 {
-            item.next = list[i+1]
+        if i < len(shortest)-1 {
+            item.next = shortest[i+1]
         }
     }
 
     overfetch_budget := int(float32(total_size) * overfetch)
 
-    // create a 2nd slice, sorted by ascending distance to next range
-    shortest := make([]*OverfetchListItem, len(list))
-    copy(shortest, list)
-
+    // sort by ascending distance to next range
     sort.Slice(shortest, func(i, j int) bool {
         return shortest[i].BytesToNext < shortest[j].BytesToNext
     })
@@ -221,21 +221,21 @@ func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
         overfetch_budget -= int(item.BytesToNext)
     }
 
-    // copy out the result structs
-    result := make([]OverfetchRange, len(shortest))
-
     sort.Slice(shortest, func(i, j int) bool {
-        return shortest[i].Rng.DstOffset < shortest[j].Rng.DstOffset
+        return shortest[i].Rng.Length > shortest[j].Rng.Length
     })
 
-    for i, x := range shortest {
-        result[i] = OverfetchRange{
+    total_bytes := uint64(0)
+    result := list.New()
+    for _, x := range shortest {
+        result.PushBack(OverfetchRange{
             Rng: x.Rng,
             CopyDiscards: x.CopyDiscards,
-        }
+        })
+        total_bytes += x.Rng.Length
     }
 
-    return result
+    return result, total_bytes
 }
 
 // 1. Get the root directory (check that it is clustered)
@@ -252,7 +252,7 @@ func MergeRanges(ranges []SrcDstRange, overfetch float32) []OverfetchRange {
 // 10. write the leaf directories (if any)
 // 11. Get all tiles, and write directly to the output.
 
-func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, region_file string, output string, overfetch float32, dry_run bool) error {
+func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, region_file string, output string, download_threads int, overfetch float32, dry_run bool) error {
     // 1. fetch the header
 
     if bucketURL == "" {
@@ -342,10 +342,15 @@ func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, r
         leaf_ranges = append(leaf_ranges, SrcDstRange{header.LeafDirectoryOffset + leaf.Offset, 0, uint64(leaf.Length)})
     }
 
-    overfetch_leaves := MergeRanges(leaf_ranges, overfetch)
-    fmt.Printf("fetching %d dirs, %d chunks, %d requests\n", len(leaves), len(leaf_ranges), len(overfetch_leaves))
+    overfetch_leaves, _ := MergeRanges(leaf_ranges, overfetch)
+    num_overfetch_leaves := overfetch_leaves.Len()
+    fmt.Printf("fetching %d dirs, %d chunks, %d requests\n", len(leaves), len(leaf_ranges), overfetch_leaves.Len())
 
-    for _, or := range overfetch_leaves {
+    for {
+        if overfetch_leaves.Len() == 0 {
+            break
+        }
+        or := overfetch_leaves.Remove(overfetch_leaves.Front()).(OverfetchRange)
 
         slab_r, err := bucket.NewRangeReader(ctx, file, int64(or.Rng.SrcOffset), int64(or.Rng.Length), nil)
         if err != nil {
@@ -385,9 +390,12 @@ func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, r
     // we now need to re-encode this entry list using cumulative offsets
     reencoded, tile_parts, tiledata_length, addressed_tiles, tile_contents := ReencodeEntries(tile_entries)
 
-    overfetch_ranges := MergeRanges(tile_parts, overfetch)
-    fmt.Printf("fetching %d tiles, %d chunks, %d requests\n", len(reencoded), len(tile_parts), len(overfetch_ranges))
+    overfetch_ranges, total_bytes := MergeRanges(tile_parts, overfetch)
+    num_overfetch_ranges := overfetch_ranges.Len()
+    fmt.Printf("fetching %d tiles, %d chunks, %d requests\n", len(reencoded), len(tile_parts), overfetch_ranges.Len())
+
+    // TODO: takes up too much RAM
     // construct the directories
     new_root_bytes, new_leaves_bytes, _ := optimize_directories(reencoded, 16384-HEADERV3_LEN_BYTES)
@@ -414,15 +422,18 @@ func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, r
     header_bytes := serialize_header(header)
 
-    total_bytes := uint64(0)
-    for _, x := range overfetch_ranges {
-        total_bytes += x.Rng.Length
+    total_actual_bytes := uint64(0)
+    for _, x := range tile_parts {
+        total_actual_bytes += x.Length
     }
 
     if !dry_run {
         outfile, err := os.Create(output)
         defer outfile.Close()
+
+        outfile.Truncate(127 + int64(len(new_root_bytes)) + int64(header.MetadataLength) + int64(len(new_leaves_bytes)) + int64(total_actual_bytes))
+
         _, err = outfile.Write(header_bytes)
         if err != nil {
             return err
@@ -458,38 +469,74 @@ func Extract(logger *log.Logger, bucketURL string, file string, maxzoom uint8, r
             "fetching chunks",
         )
 
-        for _, or := range overfetch_ranges {
+        var mu sync.Mutex
+        downloadPart := func(or OverfetchRange) error {
             tile_r, err := bucket.NewRangeReader(ctx, file, int64(source_tile_data_offset+or.Rng.SrcOffset), int64(or.Rng.Length), nil)
             if err != nil {
                 return err
             }
+            offset_writer := io.NewOffsetWriter(outfile, int64(header.TileDataOffset)+int64(or.Rng.DstOffset))
 
             for _, cd := range or.CopyDiscards {
-                _, err := io.CopyN(io.MultiWriter(outfile, bar), tile_r, int64(cd.Wanted))
+
+                _, err := io.CopyN(io.MultiWriter(offset_writer, bar), tile_r, int64(cd.Wanted))
                 if err != nil {
                     return err
                 }
 
-                _, err = io.CopyN(io.MultiWriter(io.Discard, bar), tile_r, int64(cd.Discard))
+                _, err = io.CopyN(bar, tile_r, int64(cd.Discard))
                 if err != nil {
                     return err
                 }
             }
             tile_r.Close()
+            return nil
         }
-    }
 
-    total_actual_bytes := uint64(0)
-    for _, x := range tile_parts {
-        total_actual_bytes += x.Length
+        errs, _ := errgroup.WithContext(ctx)
+
+        for i := 0; i < download_threads; i++ {
+            work_back := (i == 0 && download_threads > 1)
+
+            errs.Go(func() error {
+                done := false
+                var or OverfetchRange
+                for {
+                    mu.Lock()
+                    if overfetch_ranges.Len() == 0 {
+                        done = true
+                    } else {
+                        if work_back {
+                            or = overfetch_ranges.Remove(overfetch_ranges.Back()).(OverfetchRange)
+                        } else {
+                            or = overfetch_ranges.Remove(overfetch_ranges.Front()).(OverfetchRange)
+                        }
+                    }
+                    mu.Unlock()
+                    if done {
+                        return nil
+                    }
+                    err := downloadPart(or)
+                    if err != nil {
+                        return err
+                    }
+                }
+
+                return nil
+            })
+        }
+
+        err = errs.Wait()
+        if err != nil {
+            return err
+        }
     }
 
-    fmt.Printf("Completed in %v seconds with 1 download thread.\n", time.Since(start))
+    fmt.Printf("Completed in %v with %v download threads.\n", time.Since(start), download_threads)
 
     total_requests := 2 // header + root
-    total_requests += num_overfetch_leaves // leaves
-    total_requests += len(overfetch_leaves) // leaves
+    total_requests += num_overfetch_leaves // leaves
     total_requests += 1 // metadata
-    total_requests += len(overfetch_ranges)
+    total_requests += num_overfetch_ranges
     fmt.Printf("Extract required %d total requests.\n", total_requests)
     fmt.Printf("Extract transferred %s (overfetch %v) for an archive size of %s\n", humanize.Bytes(total_bytes), overfetch, humanize.Bytes(total_actual_bytes))
     fmt.Println("Verify your extract is usable at https://protomaps.github.io/PMTiles/")
diff --git a/pmtiles/extract_test.go b/pmtiles/extract_test.go
index 7464b9c..403ccff 100644
--- a/pmtiles/extract_test.go
+++ b/pmtiles/extract_test.go
@@ -141,13 +141,15 @@ func TestMergeRanges(t *testing.T) {
     ranges = append(ranges, SrcDstRange{0, 0, 50})
     ranges = append(ranges, SrcDstRange{60, 60, 60})
 
-    result := MergeRanges(ranges, 0.1)
-
-    assert.Equal(t, 1, len(result))
-    assert.Equal(t, SrcDstRange{0, 0, 120}, result[0].Rng)
-    assert.Equal(t, 2, len(result[0].CopyDiscards))
-    assert.Equal(t, CopyDiscard{50, 10}, result[0].CopyDiscards[0])
-    assert.Equal(t, CopyDiscard{60, 0}, result[0].CopyDiscards[1])
+    result, total_transfer_bytes := MergeRanges(ranges, 0.1)
+
+    assert.Equal(t, 1, result.Len())
+    assert.Equal(t, uint64(120), total_transfer_bytes)
+    front := result.Front().Value.(OverfetchRange)
+    assert.Equal(t, SrcDstRange{0, 0, 120}, front.Rng)
+    assert.Equal(t, 2, len(front.CopyDiscards))
+    assert.Equal(t, CopyDiscard{50, 10}, front.CopyDiscards[0])
+    assert.Equal(t, CopyDiscard{60, 0}, front.CopyDiscards[1])
 }
 
 func TestMergeRangesMultiple(t *testing.T) {
@@ -156,9 +158,11 @@ func TestMergeRangesMultiple(t *testing.T) {
     ranges = append(ranges, SrcDstRange{60, 60, 10})
     ranges = append(ranges, SrcDstRange{80, 80, 10})
 
-    result := MergeRanges(ranges, 0.3)
-    assert.Equal(t, 1, len(result))
-    assert.Equal(t, SrcDstRange{0, 0, 90}, result[0].Rng)
-    assert.Equal(t, 3, len(result[0].CopyDiscards))
+    result, total_transfer_bytes := MergeRanges(ranges, 0.3)
+    front := result.Front().Value.(OverfetchRange)
+    assert.Equal(t, uint64(90), total_transfer_bytes)
+    assert.Equal(t, 1, result.Len())
+    assert.Equal(t, SrcDstRange{0, 0, 90}, front.Rng)
+    assert.Equal(t, 3, len(front.CopyDiscards))
     fmt.Println(result)
 }
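
The extract.go hunks above replace the single sequential download loop with a pool of errgroup workers that pull OverfetchRange items off the shared container/list guarded by a sync.Mutex; one worker drains the Length-sorted list from the back while the rest drain from the front, and each worker writes its range at the correct file position via io.NewOffsetWriter. Below is a minimal standalone sketch of that pattern, for illustration only and not part of the patch: workItem, process, queue, and threads are hypothetical stand-ins for OverfetchRange, downloadPart, overfetch_ranges, and download_threads.

// worker_pool_sketch.go — illustrative only; identifiers here are stand-ins,
// not the names used in the patch.
package main

import (
	"container/list"
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

type workItem struct{ id int } // stand-in for OverfetchRange

func main() {
	// Shared queue of work, like the *list.List returned by MergeRanges.
	queue := list.New()
	for i := 0; i < 10; i++ {
		queue.PushBack(workItem{id: i})
	}

	var mu sync.Mutex
	process := func(w workItem) error { // stand-in for downloadPart
		fmt.Println("processing", w.id)
		return nil
	}

	threads := 4
	errs, _ := errgroup.WithContext(context.Background())
	for i := 0; i < threads; i++ {
		// One worker consumes from the back of the queue, the rest from the front.
		workBack := i == 0 && threads > 1
		errs.Go(func() error {
			for {
				mu.Lock()
				if queue.Len() == 0 {
					mu.Unlock()
					return nil
				}
				var w workItem
				if workBack {
					w = queue.Remove(queue.Back()).(workItem)
				} else {
					w = queue.Remove(queue.Front()).(workItem)
				}
				mu.Unlock()
				if err := process(w); err != nil {
					return err
				}
			}
		})
	}
	if err := errs.Wait(); err != nil {
		fmt.Println("error:", err)
	}
}

Removing an element from either end of a container/list is O(1), which is presumably why MergeRanges now returns a *list.List (plus the total transfer size) instead of a slice: workers can cheaply pop work from both ends of the shared queue.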