Skip to content

Commit 8990132

Browse files
committed
Kernel-level profiling draft (see NVIDIA#1210 (comment))
1 parent fee4a6b commit 8990132

9 files changed

+250
-7
lines changed

src/device/all_gather.h

+6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ namespace {
1212
template<typename T, typename RedOp, typename Proto>
1313
__device__ __forceinline__ void runRing(ncclWorkElem *args) {
1414
const int tid = threadIdx.x;
15+
16+
if (tid == 0) {
17+
ncclShmem.kernelType = 1;
18+
}
19+
__syncthreads();
20+
1521
const int nthreads = args->nWarps*WARP_SIZE;
1622
const int bid = args->bid;
1723
const int nChannels = args->nChannels;

src/device/common.h

+13-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
#include "op128.h"
1313
#include "network/unpack/unpack_defs.h"
1414

15+
#include <cuda/std/chrono>
16+
17+
// Reads cuda::std::chrono::high_resolution_clock from device code and returns
// the tick count of "now" measured from the clock's epoch, converted to uint64_t.
// NOTE(review): the "Ns" suffix assumes the clock's duration period is
// nanoseconds — confirm against the libcu++ version in use.
__device__ __forceinline__ uint64_t getDeviceTimeNs() {
18+
return cuda::std::chrono::high_resolution_clock::now().time_since_epoch().count();
19+
}
20+
1521
#define COLL_UNROLL (ncclCollUnroll())
1622

1723
typedef void(*ncclDevFuncPtr_t)();
@@ -38,6 +44,8 @@ struct ncclShmemData {
3844
alignas(16) union {
3945
unpackShmem unpack;
4046
} devicePlugin;
47+
alignas(16) uint64_t startTiming;
48+
alignas(16) int kernelType;
4149
};
4250
static_assert(offsetof(struct ncclShmemData, work)%16 == 0, "shmem.work needs to be 16B aligned");
4351

@@ -162,7 +170,11 @@ __device__ void ncclKernelMain(struct ncclDevComm* comm, uint64_t channelMask, s
162170
__syncthreads(); // publish ncclShmem.channelId
163171
int channelId = ncclShmem.channelId;
164172
/* set abort flag to 0 */
165-
if (tid == 0) ncclShmem.aborted = 0;
173+
if (tid == 0) {
174+
ncclShmem.aborted = 0;
175+
ncclShmem.startTiming = getDeviceTimeNs();
176+
ncclShmem.kernelType = 0;
177+
}
166178

167179
if (true) {
168180
void *dst, *src;

src/device/mixed_precision_reduce_scatter.h

+36
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ class MixedPrecisionReduceScatterPrims {
252252
postPeer<Send, Recv>(sliceSize > 0);
253253
offset += sliceSize;
254254
++slice;
255+
if (shouldSaveTimings()) {
256+
timingIndex += TIMINGS_COUNT;
257+
}
255258
} while (slice < SlicePerChunk && offset < nelem);
256259
}
257260

@@ -265,6 +268,9 @@ class MixedPrecisionReduceScatterPrims {
265268
postPeer<Send, Recv>(sliceSize > 0);
266269
offset += sliceSize;
267270
++slice;
271+
if (shouldSaveTimings()) {
272+
timingIndex += TIMINGS_COUNT;
273+
}
268274
}
269275
}
270276

@@ -282,6 +288,7 @@ class MixedPrecisionReduceScatterPrims {
282288
, stepSize(stepSize)
283289
, nranks(nranks)
284290
, size(size)
291+
, timingIndex(-1)
285292
{
286293
constexpr int ThreadsPerSync = 8;
287294

@@ -291,16 +298,20 @@ class MixedPrecisionReduceScatterPrims {
291298
// Preserve the indexes from `prims_simple.h` for simplicity
292299
if (tid == 0) {
293300
flags |= RoleWaitRecv;
301+
timingIndex = 1;
294302
} else if (tid == 1) {
295303
flags |= RoleInput;
296304
} else if (tid == ThreadsPerSync) {
297305
flags |= RoleWaitSend;
306+
timingIndex = 1;
298307
} else if (tid == ThreadsPerSync + 1) {
299308
flags |= RoleOutput;
300309
} else if (tid == nthreads - 2 * ThreadsPerSync) {
301310
flags |= RolePostRecv;
311+
timingIndex = 1;
302312
} else if (tid == nthreads - ThreadsPerSync) {
303313
flags |= RolePostSend;
314+
timingIndex = 1;
304315
}
305316

306317
loadConn(ring->prev, RolePostRecv, RoleWaitRecv, true /*isRecv*/);
@@ -323,12 +334,19 @@ class MixedPrecisionReduceScatterPrims {
323334
if (flags & (RolePostSend | RolePostRecv)) {
324335
auto& group = ncclShmem.groups[0];
325336
((flags & RolePostSend) ? group.sendConns : group.recvConns)[0]->step = step;
337+
if (flags & RolePostSend && shouldSaveTimings()) {
338+
ncclShmem.channel.stepTimings[0] = (2ULL << 32) | min(timingIndex, 1 + MAXSTEPTIMINGS*TIMINGS_COUNT);
339+
}
326340
}
327341

328342
barrier();
329343
}
330344

331345
private:
346+
// Returns true when this thread should record step timings: the channel in
// shared memory has a non-null stepTimings buffer and this thread's
// timingIndex is not -1 (the value the constructor initializes it to; only
// the wait/post role threads overwrite it with 1).
__forceinline__ __device__ bool shouldSaveTimings() {
347+
return ncclShmem.channel.stepTimings != nullptr && timingIndex != -1;
348+
}
349+
332350
__device__ void barrier() {
333351
flags |= ThreadsSynced;
334352
asm volatile("bar.sync %0, %1;" :: "r"(15), "r"(nthreads) : "memory");
@@ -359,6 +377,15 @@ class MixedPrecisionReduceScatterPrims {
359377
}
360378
}
361379

380+
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
381+
const int isSend = (flags & RoleWaitSend) != 0;
382+
ncclShmem.channel.stepTimings[timingIndex + isSend] = step;
383+
if ((Recv && !Send) ? (flags & RoleWaitRecv) : (flags & RoleWaitSend)) {
384+
ncclShmem.channel.stepTimings[timingIndex + 2] = nelts*sizeof(T);
385+
}
386+
ncclShmem.channel.stepTimings[timingIndex + isSend + 3] = getDeviceTimeNs() - ncclShmem.startTiming;
387+
}
388+
362389
if (isSendNotRecv && (flags & SizesFifoEnabled)) {
363390
connSizesFifoPtr[step % NCCL_STEPS] = nelts * sizeof(T);
364391
}
@@ -382,11 +409,18 @@ class MixedPrecisionReduceScatterPrims {
382409
// Advances `step` by StepPerSlice and stores it through connStepPtr, but only
// on threads holding the post role selected by the Send/Recv template flags.
// When timing capture is active for this thread, it also writes elapsed device
// time (getDeviceTimeNs() minus ncclShmem.startTiming) into stepTimings:
//   - slot timingIndex + 7: before the step update, RolePostSend thread only;
//   - slot timingIndex + 5 (RolePostRecv) or timingIndex + 6 (RolePostSend):
//     after the step store.
// dataStored: when true on the sending poster, fence_acq_rel_sys() is issued
// before the step store.
// NOTE(review): the guard only checks timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT
// while the highest slot written is timingIndex + 7 — confirm that stepTimings
// holds at least 1 + MAXSTEPTIMINGS*TIMINGS_COUNT entries and TIMINGS_COUNT >= 8.
template <int Send, int Recv>
383410
__device__ void postPeer(bool dataStored) {
384411
if (flags & (Recv * RolePostRecv | Send * RolePostSend)) {
412+
// 1 for the RolePostSend thread, 0 otherwise; used as a slot offset below.
const int isSend = (flags & RolePostSend) != 0;
413+
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT && isSend) {
414+
// Timestamp taken before the step is advanced and stored (send side only).
ncclShmem.channel.stepTimings[timingIndex + 7] = getDeviceTimeNs() - ncclShmem.startTiming;
415+
}
385416
step += StepPerSlice;
386417
if (Send && (flags & RolePostSend) && dataStored) {
387418
// Presumably orders the stored data ahead of the step store that follows — TODO confirm.
fence_acq_rel_sys();
388419
}
389420
st_relaxed_sys_global(connStepPtr, step);
421+
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
422+
// Timestamp taken after the step store; slot +5 for recv, +6 for send.
ncclShmem.channel.stepTimings[timingIndex + isSend + 5] = getDeviceTimeNs() - ncclShmem.startTiming;
423+
}
390424
}
391425
}
392426

@@ -452,6 +486,8 @@ class MixedPrecisionReduceScatterPrims {
452486
int volatile *connSizesFifoPtr;
453487
uint64_t *connStepPtr;
454488
uint64_t connStepCache;
489+
490+
int timingIndex;
455491
};
456492

457493

src/device/prims_simple.h

+59-6
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class Primitives<
5252
uint64_t connStepCache; // Cache last seen value of (*connStepPtr)
5353
void* mhandle;
5454
void* netDeviceHandle;
55+
int timingIndex;
5556

5657
// Don't use barrier 0 as it's used by the final sync
5758
__device__ void barrier() {
@@ -125,6 +126,10 @@ class Primitives<
125126
return ld_volatile_global(ptr);
126127
}
127128

129+
// Returns true when this thread should record step timings. All three must hold:
//   - ncclShmem.kernelType is non-zero (ncclKernelMain sets it to 0; the
//     all_gather and reduce_scatter runRing entry points set it to 1 and 2);
//   - the channel in shared memory has a non-null stepTimings buffer;
//   - this thread's timingIndex is not -1 (the constructor's initial value).
__forceinline__ __device__ bool shouldSaveTimings() {
130+
return ncclShmem.kernelType != 0 && ncclShmem.channel.stepTimings != nullptr && timingIndex != -1;
131+
}
132+
128133
template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
129134
__device__ __forceinline__ void waitPeer(intptr_t srcIx, intptr_t dstIx, int offset, int nelts) {
130135
const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
@@ -141,8 +146,18 @@ class Primitives<
141146
}
142147

143148
if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) {
144-
if (isSendNotRecv && (flags & SizesFifoEnabled))
149+
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
150+
const int isSend = (flags & RoleWaitSend) != 0;
151+
ncclShmem.channel.stepTimings[timingIndex + isSend] = step;
152+
if ((Recv && !Send) ? (flags & RoleWaitRecv) : (flags & RoleWaitSend)) {
153+
ncclShmem.channel.stepTimings[timingIndex + 2] = nelts*sizeof(T);
154+
}
155+
ncclShmem.channel.stepTimings[timingIndex + isSend + 3] = getDeviceTimeNs() - ncclShmem.startTiming;
156+
}
157+
158+
if (isSendNotRecv && (flags & SizesFifoEnabled)) {
145159
connSizesFifoPtr[step%NCCL_STEPS] = nelts*sizeof(T);
160+
}
146161

147162
void **ptrs = isSendNotRecv ? (ncclShmem.groups[group].dsts + Dst)
148163
: (ncclShmem.groups[group].srcs + Src);
@@ -178,9 +193,16 @@ class Primitives<
178193
// Advances `step` by StepPerSlice and stores it through connStepPtr, but only
// on threads holding the post role selected by the Recv/Send template flags.
// When timing capture is active for this thread, it also writes elapsed device
// time (getDeviceTimeNs() minus ncclShmem.startTiming) into stepTimings:
//   - slot timingIndex + 7: before the step update, RolePostSend thread only;
//   - slot timingIndex + 5 (RolePostRecv) or timingIndex + 6 (RolePostSend):
//     after the step store.
// dataStored: when true on the sending poster, fence_acq_rel_sys() is issued
// before the step store.
// NOTE(review): the guard only checks timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT
// while the highest slot written is timingIndex + 7 — confirm that stepTimings
// holds at least 1 + MAXSTEPTIMINGS*TIMINGS_COUNT entries and TIMINGS_COUNT >= 8.
template<int Recv, int Send>
179194
inline __device__ void postPeer(bool dataStored) {
180195
if (flags & (Recv*RolePostRecv | Send*RolePostSend)) {
196+
// 1 for the RolePostSend thread, 0 otherwise; used as a slot offset below.
const int isSend = (flags & RolePostSend) != 0;
197+
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT && isSend) {
198+
// Timestamp taken before the step is advanced and stored (send side only).
ncclShmem.channel.stepTimings[timingIndex + 7] = getDeviceTimeNs() - ncclShmem.startTiming;
199+
}
181200
step += StepPerSlice;
182201
// Presumably orders the stored data ahead of the step store that follows — TODO confirm.
if (Send && (flags & RolePostSend) && dataStored) fence_acq_rel_sys();
183202
st_relaxed_sys_global(connStepPtr, step);
203+
if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
204+
// Timestamp taken after the step store; slot +5 for recv, +6 for send.
ncclShmem.channel.stepTimings[timingIndex + isSend + 5] = getDeviceTimeNs() - ncclShmem.startTiming;
205+
}
184206
}
185207
}
186208

@@ -280,6 +302,9 @@ class Primitives<
280302
postPeer<Recv, Send>(0 < sliceSize);
281303
offset += sliceSize;
282304
slice += 1;
305+
if (shouldSaveTimings()) {
306+
timingIndex += TIMINGS_COUNT;
307+
}
283308
} while (slice < SlicePerChunk && offset < nelem);
284309
}
285310

@@ -298,6 +323,9 @@ class Primitives<
298323
postPeer<Recv, Send>(0 < sliceSize);
299324
offset += sliceSize;
300325
slice += 1;
326+
if (shouldSaveTimings()) {
327+
timingIndex += TIMINGS_COUNT;
328+
}
301329
}
302330
}
303331

@@ -471,7 +499,8 @@ class Primitives<
471499
uint8_t connIndexRecv = 0, uint8_t connIndexSend = 0, struct ncclWorkElem* e = nullptr, int stepSize_=0
472500
):
473501
tid(tid), nthreads(nthreads), tidInBlock(threadIdx.x), group(group),
474-
stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_) {
502+
stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_),
503+
timingIndex(-1) {
475504

476505
// For send operations, we need an extra warp to overlap the threadfence and the copy
477506
this->nworkers = nthreads - (MaxSend > 0 && nthreads-WARP_SIZE >= 64 ? WARP_SIZE : 0);
@@ -489,15 +518,35 @@ class Primitives<
489518
index = tid % ThreadPerSync;
490519
flags = 0;
491520
if (g == 0) {
492-
if (index < nrecv) flags |= RoleWaitRecv;
521+
if (index < nrecv) {
522+
flags |= RoleWaitRecv;
523+
if (index == 0) {
524+
timingIndex = 1;
525+
}
526+
}
493527
if (index == nrecv) flags |= RoleInput;
494528
} else if (g == 1) {
495-
if (index < nsend) flags |= RoleWaitSend;
529+
if (index < nsend) {
530+
flags |= RoleWaitSend;
531+
if (index == 0) {
532+
timingIndex = 1;
533+
}
534+
}
496535
if (index == nsend) flags |= RoleOutput;
497536
} else if (g == ng - 2) {
498-
if (index < nrecv) flags |= RolePostRecv;
537+
if (index < nrecv) {
538+
flags |= RolePostRecv;
539+
if (index == 0) {
540+
timingIndex = 1;
541+
}
542+
}
499543
} else if (g == ng - 1) {
500-
if (index < nsend) flags |= RolePostSend;
544+
if (index < nsend) {
545+
flags |= RolePostSend;
546+
if (index == 0) {
547+
timingIndex = 1;
548+
}
549+
}
501550
}
502551

503552
int peer = 0;
@@ -532,6 +581,10 @@ class Primitives<
532581
if (flags & (RolePostSend|RolePostRecv)) {
533582
auto *conns = (flags & RolePostSend) ? ncclShmem.groups[group].sendConns : ncclShmem.groups[group].recvConns;
534583
conns[index]->step = step;
584+
585+
if (flags & RolePostSend && shouldSaveTimings()) {
586+
ncclShmem.channel.stepTimings[0] = (static_cast<uint64_t>(ncclShmem.kernelType) << 32) | min(timingIndex, 1 + MAXSTEPTIMINGS*TIMINGS_COUNT);
587+
}
535588
}
536589

537590
if ((flags & (AnyNetDeviceUnpack)) && (flags & (RoleWaitRecv))) {

src/device/reduce_scatter.h

+6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ namespace {
1212
template<typename T, typename RedOp, typename Proto>
1313
__device__ __forceinline__ void runRing(ncclWorkElem *args) {
1414
const int tid = threadIdx.x;
15+
16+
if (tid == 0) {
17+
ncclShmem.kernelType = 2;
18+
}
19+
__syncthreads();
20+
1521
const int nthreads = args->nWarps*WARP_SIZE;
1622
const int bid = args->bid;
1723
const int nChannels = args->nChannels;

0 commit comments

Comments
 (0)