Skip to content

Commit

Permalink
Use views to avoid thread issues
Browse files Browse the repository at this point in the history
  • Loading branch information
meggart committed Jul 7, 2023
1 parent a83aef3 commit c483c18
Showing 1 changed file with 55 additions and 54 deletions.
109 changes: 55 additions & 54 deletions src/DAT/DAT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,11 @@ function runLoop(dc::DATConfig, showprog)
mapfun(CachingPool(workers()),allRanges, on_error=identity) do r
incaches, outcaches, args = getallargs(dc)
updateinars(dc, r, incaches)
innerLoop(r, args...)
if dc.ntr[myid()] > 1
innerLoop_threaded(r, args...)
else
innerLoop(r,args...)
end
writeoutars(dc, r, outcaches)
dc.do_gc && GC.gc()
end
Expand All @@ -700,7 +704,12 @@ function runLoop(dc::DATConfig, showprog)
mapfun = showprog ? progress_map : map
mapfun(allRanges) do r
updateinars(dc, r, incaches)
innerLoop(r, args...)
if dc.ntr[1] > 1
innerLoop_threaded(r, args...)
else
@show "Running nonthreaded"
innerLoop(r,args...)
end
writeoutars(dc, r, outcaches)
dc.do_gc && GC.gc()
end
Expand All @@ -726,7 +735,6 @@ getlaxvals(a::AllLoopAxes, cI, offscur) = (
function getallargs(dc::DATConfig)
incache, outcache = getCubeCache(dc)
filters = map(ic -> ic.desc.procfilter, dc.incubes)
inworkar, outworkar = generateworkarrays(dc)
axvals = if dc.include_loopvars
lax = (dc.LoopAxes...,)
AllLoopAxes(lax)
Expand Down Expand Up @@ -760,7 +768,7 @@ function getallargs(dc::DATConfig)
end
incache,
outcache,
(fu, inarsbc, outarsbc, filters, inworkar, outworkar, axvals, adda, kwa)
(fu, inarsbc, outarsbc, filters, axvals, adda, kwa)
end


Expand Down Expand Up @@ -1078,32 +1086,23 @@ function distributeLoopRanges(block_size::NTuple{N,Int}, loopR::NTuple{N,Int}, c
Iterators.product(allranges...)
end

"""
    generateworkarrays(dc::DATConfig)

Build one work array per input cube and one per output cube of `dc`.

Every cube in `dc.incubes` and `dc.outcubes` is handed to `getworkarray`
together with `dc.ntr[myid()]`, i.e. the `ntr` entry belonging to the process
this code runs on (presumably the thread count of that process; confirm
against `DATConfig`).

Returns the tuple `(inwork, outwork)`, each holding the results in the same
order as the corresponding cube collection.
"""
function generateworkarrays(dc::DATConfig)
    inputbuffers = map(dc.incubes) do cube
        getworkarray(cube, dc.ntr[myid()])
    end
    outputbuffers = map(dc.outcubes) do cube
        getworkarray(cube, dc.ntr[myid()])
    end
    return inputbuffers, outputbuffers
end

function innercode(
f,
cI,
xinBC,
xoutBC,
filters,
inwork,
outwork,
axvalcreator,
offscur,
addargs,
kwargs,
)
ithr = Threads.threadid()
#Pick the correct array according to thread
myinwork = map(i -> i[ithr], inwork)
myoutwork = map(i -> i[ithr], outwork)
#Copy data into work arrays
foreach(myinwork, xinBC) do iw, x
YAXArrayBase.getdata(iw) .= view(x, cI.I...)
myinwork = map(xinBC) do x
view(x, cI.I...)
end
myoutwork = map(xoutBC) do x
view(x, cI.I...)
end
#Apply filters
mvs = map(docheck, filters, myinwork)
Expand All @@ -1116,57 +1115,59 @@ function innercode(
#Finally call the function
f(myoutwork..., myinwork..., laxval..., addargs...; kwargs...)
end
#Copy data into output array
foreach((iw, x) -> view(x, cI.I...) .= YAXArrayBase.getdata(iw), myoutwork, xoutBC)
end

using DataStructures: OrderedDict
using Base.Cartesian
"""
    innerLoop_threaded(loopRanges, f, xinBC, xoutBC, filters, axvalcreator, addargs, kwargs)

Run `innercode` once for every index of the chunk described by `loopRanges`,
spreading the iterations over threads with `Threads.@threads`.

# Arguments
- `loopRanges`: collection of index ranges, one per loop axis; only `first`
  and `length` of each range are used here.
- `f`, `xinBC`, `xoutBC`, `filters`, `axvalcreator`, `addargs`, `kwargs`:
  forwarded unchanged to `innercode`.

The iteration index `cI` is chunk-local (it starts at 1 along every axis); the
per-axis offsets `first(range) - 1` are passed alongside it so that `innercode`
can relate it to the position in the full loop ranges.

NOTE(review): iterations run concurrently, so `innercode` must not write to
the same memory for two different `cI`; confirm this for the user function.
"""
@noinline function innerLoop_threaded(
    loopRanges,
    f,
    xinBC,
    xoutBC,
    filters,
    axvalcreator,
    addargs,
    kwargs,
)
    # Per-axis offset of this chunk: one less than the start of each range.
    chunkoffsets = map(rng -> first(rng) - 1, loopRanges)
    # Chunk-local index space running from 1 to the length of each range.
    localindices = CartesianIndices(map(rng -> 1:length(rng), loopRanges))
    Threads.@threads for cI in localindices
        innercode(
            f, cI, xinBC, xoutBC, filters, axvalcreator, chunkoffsets, addargs, kwargs,
        )
    end
    return nothing
end

@noinline function innerLoop(
loopRanges,
f,
xinBC,
xoutBC,
filters,
inwork,
outwork,
axvalcreator,
addargs,
kwargs,
)
offscur = map(i -> (first(i) - 1), loopRanges)
if length(inwork[1]) == 1
for cI in CartesianIndices(map(i -> 1:length(i), loopRanges))
innercode(
f,
cI,
xinBC,
xoutBC,
filters,
inwork,
outwork,
axvalcreator,
offscur,
addargs,
kwargs,
)
end
else
Threads.@threads :static for cI in CartesianIndices(map(i -> 1:length(i), loopRanges))
innercode(
f,
cI,
xinBC,
xoutBC,
filters,
inwork,
outwork,
axvalcreator,
offscur,
addargs,
kwargs,
)
end
for cI in CartesianIndices(map(i -> 1:length(i), loopRanges))
innercode(
f,
cI,
xinBC,
xoutBC,
filters,
axvalcreator,
offscur,
addargs,
kwargs,
)
end
end

Expand Down

0 comments on commit c483c18

Please sign in to comment.