@@ -52,6 +52,7 @@ class Primitives<
52
52
uint64_t connStepCache; // Cache last seen value of (*connStepPtr)
53
53
void * mhandle;
54
54
void * netDeviceHandle;
55
+ int timingIndex;
55
56
56
57
// Don't use barrier 0 as it's used by the final sync
57
58
__device__ void barrier () {
@@ -125,6 +126,10 @@ class Primitives<
125
126
return ld_volatile_global (ptr);
126
127
}
127
128
129
// Tells whether this thread should write step-timing samples.
// All three conditions must hold, checked in this order:
//   1. ncclShmem.kernelType is non-zero,
//   2. the channel exposes a non-null stepTimings buffer,
//   3. timingIndex is not -1 (the value the constructor initializes it to).
__forceinline__ __device__ bool shouldSaveTimings() {
  if (ncclShmem.kernelType == 0) return false;
  if (ncclShmem.channel.stepTimings == nullptr) return false;
  return timingIndex != -1;
}
132
+
128
133
template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
129
134
__device__ __forceinline__ void waitPeer (intptr_t srcIx, intptr_t dstIx, int offset, int nelts) {
130
135
const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
@@ -141,8 +146,18 @@ class Primitives<
141
146
}
142
147
143
148
if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) {
144
- if (isSendNotRecv && (flags & SizesFifoEnabled))
149
+ if (shouldSaveTimings () && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
150
+ const int isSend = (flags & RoleWaitSend) != 0 ;
151
+ ncclShmem.channel .stepTimings [timingIndex + isSend] = step;
152
+ if ((Recv && !Send) ? (flags & RoleWaitRecv) : (flags & RoleWaitSend)) {
153
+ ncclShmem.channel .stepTimings [timingIndex + 2 ] = nelts*sizeof (T);
154
+ }
155
+ ncclShmem.channel .stepTimings [timingIndex + isSend + 3 ] = getDeviceTimeNs () - ncclShmem.startTiming ;
156
+ }
157
+
158
+ if (isSendNotRecv && (flags & SizesFifoEnabled)) {
145
159
connSizesFifoPtr[step%NCCL_STEPS] = nelts*sizeof (T);
160
+ }
146
161
147
162
void **ptrs = isSendNotRecv ? (ncclShmem.groups [group].dsts + Dst)
148
163
: (ncclShmem.groups [group].srcs + Src);
@@ -178,9 +193,16 @@ class Primitives<
178
193
// Advances this thread's step counter by StepPerSlice and stores it through
// connStepPtr. When timing capture is enabled for this thread, it also writes
// elapsed-time samples (getDeviceTimeNs() - ncclShmem.startTiming) into the
// current timing record, before and after the step is stored.
//
// Template parameters Recv / Send select which post roles are honored.
// dataStored: when true, and the thread has RolePostSend with Send enabled,
//             fence_acq_rel_sys() is called before the step is stored.
template <int Recv, int Send>
inline __device__ void postPeer(bool dataStored) {
  // Threads without an enabled post role have nothing to do.
  if (!(flags & (Recv*RolePostRecv | Send*RolePostSend))) return;

  const int postsSend = (flags & RolePostSend) != 0;

  // Offset +7 of the record: sampled before the step is advanced, send side only.
  if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT && postsSend) {
    ncclShmem.channel.stepTimings[timingIndex + 7] = getDeviceTimeNs() - ncclShmem.startTiming;
  }

  step += StepPerSlice;
  if (Send && (flags & RolePostSend) && dataStored) {
    fence_acq_rel_sys();
  }
  st_relaxed_sys_global(connStepPtr, step);

  // Offset +5 when the thread lacks RolePostSend, +6 when it has it:
  // sampled right after the step value has been stored.
  if (shouldSaveTimings() && timingIndex <= MAXSTEPTIMINGS*TIMINGS_COUNT) {
    ncclShmem.channel.stepTimings[timingIndex + 5 + postsSend] = getDeviceTimeNs() - ncclShmem.startTiming;
  }
}
186
208
@@ -280,6 +302,9 @@ class Primitives<
280
302
postPeer<Recv, Send>(0 < sliceSize);
281
303
offset += sliceSize;
282
304
slice += 1 ;
305
+ if (shouldSaveTimings ()) {
306
+ timingIndex += TIMINGS_COUNT;
307
+ }
283
308
} while (slice < SlicePerChunk && offset < nelem);
284
309
}
285
310
@@ -298,6 +323,9 @@ class Primitives<
298
323
postPeer<Recv, Send>(0 < sliceSize);
299
324
offset += sliceSize;
300
325
slice += 1 ;
326
+ if (shouldSaveTimings ()) {
327
+ timingIndex += TIMINGS_COUNT;
328
+ }
301
329
}
302
330
}
303
331
@@ -471,7 +499,8 @@ class Primitives<
471
499
uint8_t connIndexRecv = 0 , uint8_t connIndexSend = 0 , struct ncclWorkElem * e = nullptr , int stepSize_=0
472
500
):
473
501
tid(tid), nthreads(nthreads), tidInBlock(threadIdx.x), group(group),
474
- stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof (T) : stepSize_) {
502
+ stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof (T) : stepSize_),
503
+ timingIndex(-1 ) {
475
504
476
505
// For send operations, we need an extra warp to overlap the threadfence and the copy
477
506
this ->nworkers = nthreads - (MaxSend > 0 && nthreads-WARP_SIZE >= 64 ? WARP_SIZE : 0 );
@@ -489,15 +518,35 @@ class Primitives<
489
518
index = tid % ThreadPerSync;
490
519
flags = 0 ;
491
520
if (g == 0 ) {
492
- if (index < nrecv) flags |= RoleWaitRecv;
521
+ if (index < nrecv) {
522
+ flags |= RoleWaitRecv;
523
+ if (index == 0 ) {
524
+ timingIndex = 1 ;
525
+ }
526
+ }
493
527
if (index == nrecv) flags |= RoleInput;
494
528
} else if (g == 1 ) {
495
- if (index < nsend) flags |= RoleWaitSend;
529
+ if (index < nsend) {
530
+ flags |= RoleWaitSend;
531
+ if (index == 0 ) {
532
+ timingIndex = 1 ;
533
+ }
534
+ }
496
535
if (index == nsend) flags |= RoleOutput;
497
536
} else if (g == ng - 2 ) {
498
- if (index < nrecv) flags |= RolePostRecv;
537
+ if (index < nrecv) {
538
+ flags |= RolePostRecv;
539
+ if (index == 0 ) {
540
+ timingIndex = 1 ;
541
+ }
542
+ }
499
543
} else if (g == ng - 1 ) {
500
- if (index < nsend) flags |= RolePostSend;
544
+ if (index < nsend) {
545
+ flags |= RolePostSend;
546
+ if (index == 0 ) {
547
+ timingIndex = 1 ;
548
+ }
549
+ }
501
550
}
502
551
503
552
int peer = 0 ;
@@ -532,6 +581,10 @@ class Primitives<
532
581
if (flags & (RolePostSend|RolePostRecv)) {
533
582
auto *conns = (flags & RolePostSend) ? ncclShmem.groups [group].sendConns : ncclShmem.groups [group].recvConns ;
534
583
conns[index ]->step = step;
584
+
585
+ if (flags & RolePostSend && shouldSaveTimings ()) {
586
+ ncclShmem.channel .stepTimings [0 ] = (static_cast <uint64_t >(ncclShmem.kernelType ) << 32 ) | min (timingIndex, 1 + MAXSTEPTIMINGS*TIMINGS_COUNT);
587
+ }
535
588
}
536
589
537
590
if ((flags & (AnyNetDeviceUnpack)) && (flags & (RoleWaitRecv))) {
0 commit comments