diff --git a/examples/Scalars.jl b/examples/Scalars.jl index 36461c4a..e44a3f30 100644 --- a/examples/Scalars.jl +++ b/examples/Scalars.jl @@ -23,3 +23,28 @@ with_logger(logger) do @info "scalar/complex" y = z end end + + +################control step increments with context################ +with_logger(logger) do + for epoch in 1:10 + for i=1:100 + # increments global_step by default + with_TBLogger_hold_step() do + # all of these are logged at the same global_step + # and the logger global_step is only then increased + @info "train1/scalar" val=i + @info "train2/scalar" val2=i/2 + @info "train3/scalar" val3=100-i + end + end + # step increment at end can be disabled for easy train/test sync + with_TBLogger_hold_step(;step_at_end=false) do + # all of these are logged at the same global_step + # and the logger global_step is only then increased + @info "test1/scalar" epoch=epoch + @info "test2/scalar" epoch2=epoch^2 + @info "test3/scalar" epoch3=epoch^3 + end + end +end diff --git a/src/TBLogger.jl b/src/TBLogger.jl index d998160a..dc80994b 100644 --- a/src/TBLogger.jl +++ b/src/TBLogger.jl @@ -298,3 +298,53 @@ Base.show(io::IO, mime::MIME"text/plain", tbl::TBLogger) = begin """ Base.print(io, str) end + +""" +`with_TBLogger_hold_step(f, [step]; step_at_end::Bool=true)` +Context function to ease control of logging steps and synchronization. +Amount of step increment can be controlled via `set_step_increment!``. + +Example: +```julia +with_logger(lg) do + for epoch in 1:10 + for i=1:100 + # increments global_step by default + with_TBLogger_hold_step() do + # all of these are logged at the same global_step + # and the logger global_step is only then increased + @info "train1/scalar" i=i + @info "train2/scalar" i2=i^2 + @info "train3/scalar" i3=i^3 + end + end + # step increment at end can be disabled for easy train/test sync + with_TBLogger_hold_step(;step_at_end=false) do + # all of these are logged at the same global_step + # and the logger global_step is only then increased + @info "test1/scalar" i=i + @info "test2/scalar" i2=i^2 + @info "test3/scalar" i3=i^3 + end + end +end +``` + +""" +function with_TBLogger_hold_step(f, step::Int; step_at_end::Bool=true) + logger = CoreLogging.current_logger() + @assert logger isa TBLogger "with_TBLogger_hold_step: current logger is not a TBLogger, cannot establish current step automatically" + curr_step = logger.global_step + curr_increment = logger.step_increment + set_step!(logger, step) + set_step_increment!(logger, 0) + f() + set_step!(logger, curr_step) + set_step_increment!(logger, curr_increment) + step_at_end && increment_step!(logger, curr_increment) +end +function with_TBLogger_hold_step(f; step_at_end::Bool=true) + logger = CoreLogging.current_logger() + isa(logger, TBLogger) || error("with_TBLogger_hold_step: current logger is not a TBLogger, cannot establish current step automatically") + with_TBLogger_hold_step(f, logger.global_step; step_at_end=step_at_end) +end \ No newline at end of file diff --git a/src/TensorBoardLogger.jl b/src/TensorBoardLogger.jl index 758d74d0..0a1dd5d1 100644 --- a/src/TensorBoardLogger.jl +++ b/src/TensorBoardLogger.jl @@ -19,7 +19,7 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info, handle_message, shouldlog, min_enabled_level, catch_exceptions, with_logger, NullLogger -export TBLogger, reset!, set_step!, increment_step!, set_step_increment! +export TBLogger, reset!, set_step!, increment_step!, set_step_increment!, with_TBLogger_hold_step export log_histogram, log_value, log_vector, log_text, log_image, log_images, log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar export map_summaries, TBReader