Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Take non-equidistant samples and adapt mass matrix to Fisher information #114

Merged
merged 9 commits into from
Aug 25, 2022

Conversation

mschauer
Copy link
Owner

Adresses #113 . Supercedes #112 .

@mschauer
Copy link
Owner Author

This is how it would work with your iterator, @cscherrer

M = PDMats.PDiagMat(ones(d))


Z = BouncyParticle(missing, # graphical structure 
    missing, # MAP estimate, unused
    rate, # momentum refreshment rate and sample saving rate 
    1-1/n, # momentum correlation / only gradually change momentum in refreshment/momentum update
    M, # metric
    missing # legacy
) 

and then the collect_sampler gets changed into

function collect_sampler2(t, sampler, n; adapt_mass=true, progress=true, progress_stops=20, ra_offset=0)
    if progress
        prg = Progress(progress_stops, 1)
    else
        prg = missing
    end
    stops = ismissing(prg) ? 0 : max(prg.n - 1, 0) # allow one stop for cleanup
    nstop = n/stops
    d = length(sampler.u0[2][1])

    x1 = t(sampler.u0[2][1])
    tv = chainvec(x1, n)
    ϕ = iterate(sampler)
    j = 1
    local state
    M = sampler.F.U
    if adapt_mass
       m = 1 ./ M.diag
    end
    while ϕ !== nothing && j < n
        j += 1
        val, state = ϕ
        tv[j] = t(val[2])
        if adapt_mass
            @. m =  m + (state[2]^2 - m)/(ra_offset + j-1) # running average shifted by offset
            state = ZZB.set_action(state, :invalid)
            v = state[1][2][2] # get velocity
            PDMats.whiten!(M, v)
            @. M.diag = 1/m
            PDMats.unwhiten!(M, v)
        end
        ϕ = iterate(sampler, state)
        if j > nstop
            nstop += n/stops
            next!(prg) 
        end 
    end
    ismissing(prg) || ProgressMeter.finish!(prg)
    tv, (;uT=state[1], acc=state[3][1], total=state[3][2], bound=state[4].c)
end

@mschauer mschauer merged commit f0f74e6 into master Aug 25, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant