@@ -26,6 +26,7 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
+static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};

template<int n> struct TensorListMetadata
{
@@ -35,6 +36,15 @@ template<int n> struct TensorListMetadata
  int block_to_chunk[depth_to_max_blocks[n-1]];
};

+template<int n> struct TensorListScalarListMetadata
+{
+  void* addresses[n][depth_to_max_tensors_scalarlist[n-1]];
+  int sizes[depth_to_max_tensors_scalarlist[n-1]];
+  double scalar_vals[depth_to_max_tensors_scalarlist[n-1]];
+  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
+  int block_to_chunk[depth_to_max_blocks[n-1]];
+};
+
template<typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ void
@@ -49,11 +59,71 @@ multi_tensor_apply_kernel(
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
+    at::ArrayRef<double> scalars,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
  const cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
+  size_t n_tensors = tensor_lists[0].size();
+  TensorListScalarListMetadata<depth> tensorListMeta;
+
+  int loc_block_info = 0;
+  int loc_tensor_info = 0;
+  for (size_t t = 0; t < n_tensors; t++) {
+
+    tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t];
+
+    tensorListMeta.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
+    for (int d = 0; d < depth; d++) {
+      tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+    }
+    loc_tensor_info++;
+
+    int chunks = (tensor_lists[0][t].numel() + kChunkSize - 1)/kChunkSize;
+    for (int chunk = 0; chunk < chunks; chunk++) {
+      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
+      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
+      loc_block_info++;
+
+      bool tensors_full = (loc_tensor_info == depth_to_max_tensors_scalarlist[depth-1] &&
+          chunk == chunks - 1);
+      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
+      bool last_chunk = (t == n_tensors - 1 && chunk == chunks - 1);
+
+      if (tensors_full || blocks_full || last_chunk) {
+        multi_tensor_apply_kernel<<<loc_block_info, kBlockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+            tensorListMeta,
+            callable,
+            args...);
+
+        AT_CUDA_CHECK(cudaGetLastError());
+
+        // Reset.
+        loc_block_info = 0;
+        if (chunk == chunks - 1) {
+          loc_tensor_info = 0;
+        }
+        else {
+          tensorListMeta.sizes[0] = tensorListMeta.sizes[loc_tensor_info-1];
+          tensorListMeta.scalar_vals[0] = tensorListMeta.scalar_vals[loc_tensor_info-1];
+          for (int d = 0; d < depth; d++) {
+            tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1];
+          }
+          loc_tensor_info = 1;
+        }
+      }
+    }
+  }
+}
+

+template<int depth, typename T, typename... ArgTypes>
+void multi_tensor_apply(
+    std::vector<std::vector<at::Tensor>>& tensor_lists,
+    T callable,
+    ArgTypes... args) {
+  TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
+  const cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
  size_t n_tensors = tensor_lists[0].size();
  TensorListMetadata<depth> tensorListMeta;
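Note (not part of the diff): the scalar-list table tops out at 96 tensors for depth 1, versus 110 in depth_to_max_tensors, because each tensor slot now also carries an 8-byte double in scalar_vals while the whole struct is still passed by value as a kernel launch argument, which the comment above caps at roughly 4KB. The sketch below is a standalone sanity check of that budget; it mirrors the constants and struct layout added in this commit and assumes a typical 64-bit target (8-byte pointers and doubles, 4-byte ints). It is illustrative only, not code from the PR.

// Standalone sketch (not from the commit): mirrors the new metadata struct and
// checks it stays under the ~4KB kernel-argument limit noted in the source comment.
#include <cstdio>

static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template<int n> struct TensorListScalarListMetadata
{
  void* addresses[n][depth_to_max_tensors_scalarlist[n-1]];
  int sizes[depth_to_max_tensors_scalarlist[n-1]];
  double scalar_vals[depth_to_max_tensors_scalarlist[n-1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]];
};

// All supported depths must fit in the launch argument budget.
static_assert(sizeof(TensorListScalarListMetadata<1>) < 4096, "depth 1 must fit");
static_assert(sizeof(TensorListScalarListMetadata<2>) < 4096, "depth 2 must fit");
static_assert(sizeof(TensorListScalarListMetadata<3>) < 4096, "depth 3 must fit");
static_assert(sizeof(TensorListScalarListMetadata<4>) < 4096, "depth 4 must fit");
static_assert(sizeof(TensorListScalarListMetadata<5>) < 4096, "depth 5 must fit");

int main() {
  std::printf("depth 1: %zu bytes\n", sizeof(TensorListScalarListMetadata<1>));
  std::printf("depth 5: %zu bytes\n", sizeof(TensorListScalarListMetadata<5>));
}

The host-side loop in the new overload keeps the launch path allocation-free by filling this struct in place: when the block table fills up mid-tensor, it launches the kernel, copies the current tensor's size, scalar, and addresses into slot 0, and resets loc_tensor_info to 1 so the remaining chunks go out with the next launch.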