Skip to content

Use thread adoption to handle log messages. #2754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 29 additions & 52 deletions lib/cublas/CUBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,56 +210,45 @@ end

## logging

const MAX_LOG_BUFLEN = UInt(1024*1024)
const log_buffer = Vector{UInt8}(undef, MAX_LOG_BUFLEN)
const log_cursor = Threads.Atomic{UInt}(0)
const log_cond = Ref{Base.AsyncCondition}() # root
# CUBLAS calls the log callback multiple times for each message, so we need to buffer them
const log_buffer = IOBuffer()

function log_message(ptr)
# NOTE: this function may be called from unmanaged threads (by cublasXt),
# so we can't even allocate, let alone perform I/O.
len = @ccall strlen(ptr::Cstring)::Csize_t
old_cursor = log_cursor[]
new_cursor = old_cursor + len+1
if new_cursor >= MAX_LOG_BUFLEN
# overrun
return
end
global log_buffer
str = unsafe_string(ptr)

@ccall memmove((pointer(log_buffer)+old_cursor)::Ptr{Nothing},
pointer(ptr)::Ptr{Nothing}, (len+1)::Csize_t)::Nothing
log_cursor[] = new_cursor # the consumer handles CAS'ing this value
# flush if we've started a new log message
if startswith(str, r"[A-Z]!")
flush_log_messages()
end

# avoid code that depends on the runtime (even the unsafe_convert from ccall does?!)
assume(isassigned(log_cond))
@ccall uv_async_send(log_cond[].handle::Ptr{Nothing})::Cint
# append the lines to the buffer
println(log_buffer, str)

return
end

function _log_message(blob)
function flush_log_messages()
global log_buffer
message = String(take!(log_buffer))
isempty(message) && return

# the message format isn't documented, but it looks like a message starts with a capital
# and the severity (e.g. `I!`), and subsequent lines start with a lowercase mark (`!i`)
#
# lines are separated by a \0 if they came in separately, but there may also be multiple
# actual lines separated by \n in each message.
for message in split(blob, r"[\0\n]+(?=[A-Z]!)")
code = message[1]
lines = split(message[3:end], r"[\0\n]+[a-z]!")
submessage = join(lines, '\n')
if code == 'I'
@debug submessage
elseif code == 'W'
@warn submessage
elseif code == 'E'
@error submessage
elseif code == 'F'
error(submessage)
else
@info "Unknown log message, please file an issue.\n$message"
end
code = message[1]
lines = split(message[3:end], r"\n+[a-z]!")
message = join(strip.(lines), '\n')
if code == 'I'
@debug message
elseif code == 'W'
@warn message
elseif code == 'E'
@error message
elseif code == 'F'
error(message)
else
@info "Unknown log message, please file an issue.\n$message"
end
return
end

function __init__()
Expand All @@ -273,21 +262,9 @@ function __init__()
# register a log callback
if !Sys.iswindows() && # NVIDIA bug #3321130 &&
!precompiling && (isdebug(:init, CUBLAS) || Base.JLOptions().debug_level >= 2)
log_cond[] = Base.AsyncCondition() do async_cond
blob = ""
while true
message_length = log_cursor[]
blob = unsafe_string(pointer(log_buffer), message_length)
if Threads.atomic_cas!(log_cursor, message_length, UInt(0)) == message_length
break
end
end
_log_message(blob)
return
end

callback = @cfunction(log_message, Nothing, (Cstring,))
cublasSetLoggerCallback(callback)
atexit(flush_log_messages)
end
end

Expand Down
31 changes: 5 additions & 26 deletions lib/cudnn/src/cuDNN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,6 @@ end

## logging

const log_messages = []
const log_lock = ReentrantLock()
const log_cond = Ref{Any}() # root

function log_message(sev, udata, dbg_ptr, ptr)
dbg = unsafe_load(dbg_ptr)

Expand All @@ -131,20 +127,11 @@ function log_message(sev, udata, dbg_ptr, ptr)
end
len += 1
end
str = unsafe_string(ptr, len) # XXX: can this yield?

# print asynchronously
Base.@lock log_lock begin
push!(log_messages, (; sev, dbg, str))
end
ccall(:uv_async_send, Cint, (Ptr{Cvoid},), udata)
str = unsafe_string(ptr, len)

return
end

function _log_message(sev, dbg, str)
# split into lines and report
lines = split(str, '\0')
msg = join(lines, '\n')
msg = join(strip.(lines), '\n')
if sev == CUDNN_SEV_INFO
@debug msg
elseif sev == CUDNN_SEV_WARNING
Expand All @@ -154,6 +141,7 @@ function _log_message(sev, dbg, str)
elseif sev == CUDNN_SEV_FATAL
error(msg)
end

return
end

Expand Down Expand Up @@ -182,18 +170,9 @@ function __init__()

# register a log callback
if !precompiling && (isdebug(:init, cuDNN) || Base.JLOptions().debug_level >= 2)
log_cond[] = Base.AsyncCondition() do async_cond
Base.@lock log_lock begin
while length(log_messages) > 0
message = popfirst!(log_messages)
_log_message(message...)
end
end
end

callback = @cfunction(log_message, Nothing,
(cudnnSeverity_t, Ptr{Cvoid}, Ptr{cudnnDebug_t}, Ptr{UInt8}))
cudnnSetCallback(typemax(UInt32), log_cond[], callback)
cudnnSetCallback(typemax(UInt32), C_NULL, callback)
end

_initialized[] = true
Expand Down