diff --git a/src/utils.jl b/src/utils.jl index 20d16596ed..a84afe2263 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -551,8 +551,9 @@ _eltype(::AbstractArray{T}) where T = T """ throttle(f, timeout; leading=true, trailing=false) -Return a function that when invoked, will only be triggered at most once -during `timeout` seconds. +Return a function that when called, will only call the given `f` at most +once during `timeout` seconds. Any arguments passed to this new function +are passed to `f`. Normally, the throttled function will run as much as it can, without ever going more than once per `wait` duration; but if you'd like to disable the @@ -561,17 +562,27 @@ the trailing edge, pass `trailing=true`. # Examples ```jldoctest -julia> a = Flux.throttle(() -> println("Flux"), 2); +julia> noarg = Flux.throttle(() -> println("Flux"), 2); -julia> for i = 1:4 # a called in alternate iterations - a() +julia> for i in 1:4 + noarg() # println called in alternate iterations sleep(1) end Flux Flux + +julia> onearg = Flux.throttle(i -> println("step = ", i), 1); + +julia> for i in 1:10 + onearg(i) + sleep(0.3) + end +step = 1 +step = 5 +step = 9 ``` """ -function throttle(f, timeout; leading=true, trailing=false) +function throttle(f, timeout::Real; leading=true, trailing=false) cooldown = true later = nothing result = nothing @@ -603,6 +614,44 @@ function throttle(f, timeout; leading=true, trailing=false) end end +""" + @throttle timeout expr + +Evaluates the given expression at most once every `timeout` seconds. + +Internally, it uses [`throttle`](@ref Flux.throttle). But instead of +defining a function outside the loop, it lets you place the code inside +the loop. + +# Example +```jldoctest +julia> for i in 1:20 + j = 100i + sleep(0.2) + Flux.@throttle 0.9 if iseven(i) + println("i = ", i, ", and j = ", j) + else + println("i = ", i) + end + end +i = 1 +i = 6, and j = 600 +i = 11 +i = 16, and j = 1600 +``` +""" +macro throttle(timeout::Real, ex) + expr = macroexpand(__module__, ex) + vars = unique(_allsymbols(expr)) + @gensym fast slow + Base.eval(__module__, :($fast($(vars...)) = $expr)) + Base.eval(__module__, :(const $slow = $throttle($fast, $timeout))) + :($slow($(vars...))) |> esc +end + +_allsymbols(s::Symbol) =[s] +_allsymbols(other) = Symbol[] +_allsymbols(ex::Expr) = vcat(_allsymbols.(ex.args)...) """ modules(m) @@ -675,7 +724,6 @@ julia> loss() = rand(); julia> trigger = Flux.patience(() -> loss() < 1, 3); - julia> for i in 1:10 @info "Epoch \$i" trigger() && break