This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit a655116

lw authored and facebook-github-bot committed on Sep 15, 2020
Clear up ownership for shm segments and ringbuffers
Summary: I found it odd that we were using shared_ptrs to store some parts of our ringbuffer/shm-segment code. I can see two reasons:

- It allowed both the consumer and the producer to hold a shared_ptr to the ringbuffer, so that multiple of them could share the same ringbuffer, which would be kept alive as long as there were users. In practice I don't think this mattered: each ringbuffer had only one user, both "in the code" (one consumer or one producer, never both) and "logically" (it was owned by one side of the connection, or by the reactor, ..., and destroyed when they were gone). The shared_ptrs weren't buying us anything and in fact were weakening the ownership model.
- Its aliasing constructor allowed us to bind the lifetime of the shared memory segment to that of the ringbuffer object stored in it, thus allowing us to keep a bunch of things alive by holding just one pointer.

So in the end this resulted in a lot of implicit links, with the end user (the connection, the reactor, ...) holding just a shared_ptr to a consumer/producer, which in turn owned (in a pretend-shared way) a lot of resources, including mmapped memory and file descriptors.

Here I try to make the ownership model more explicit, by separating the resource (the Segment class, which now becomes a movable RAII wrapper owning the fd and the mmap) from the shallow, stateless helpers used to access it (the ringbuffer, which is a pair of pointers, and the consumer/producer, which are little more than a collection of methods). Consumers/producers are in fact now so simple that they no longer have a common base class and are created on demand when needed.

This removes the usage of shared_ptrs entirely, which I think clarifies the ownership and gives us stronger guarantees that resources will be cleaned up, and about when that will occur.
Also, it allows us to store the ringbuffer header on the stack or as a field of another object, without having to resort to shared_ptr contortions (an aliased shared_ptr with a custom empty destructor). This will come in handy later, for InfiniBand.

Reviewed By: beauby

Differential Revision: D23567125

fbshipit-source-id: c229c7f96655324787eda02c8545ab7bafadb885
1 parent feaf9f3 commit a655116

File tree

13 files changed: +317 -302 lines changed
 

tensorpipe/test/util/ringbuffer/ringbuffer_test.cc (+98 -85)
Large diffs are not rendered by default.

tensorpipe/test/util/ringbuffer/shm_ringbuffer_test.cc (+26 -12)

@@ -22,20 +22,22 @@
 #include <gtest/gtest.h>
 
 using namespace tensorpipe::util::ringbuffer;
+using namespace tensorpipe::util::shm;
 using namespace tensorpipe::transport::shm;
 
 // Same process produces and consumes share memory through different mappings.
 TEST(ShmRingBuffer, SameProducerConsumer) {
-  // This must stay alive for the file descriptors to remain open.
-  std::shared_ptr<RingBuffer> producer_rb;
   int header_fd = -1;
   int data_fd = -1;
   {
     // Producer part.
     // Buffer large enough to fit all data and persistent
     // (needs to be unlinked up manually).
-    std::tie(header_fd, data_fd, producer_rb) = shm::create(256 * 1024);
-    Producer prod{producer_rb};
+    Segment header_segment;
+    Segment data_segment;
+    RingBuffer rb;
+    std::tie(header_segment, data_segment, rb) = shm::create(256 * 1024);
+    Producer prod{rb};
 
     // Producer loop. It all fits in buffer.
     int i = 0;
@@ -44,12 +46,20 @@ TEST(ShmRingBuffer, SameProducerConsumer) {
       EXPECT_EQ(ret, sizeof(i));
       ++i;
     }
+
+    // Duplicate the file descriptors so that the shared memory remains alive
+    // when the original fds are closed by the segments' destructors.
+    header_fd = ::dup(header_segment.getFd());
+    data_fd = ::dup(data_segment.getFd());
   }
 
   {
     // Consumer part.
     // Map file again (to a different address) and consume it.
-    auto rb = shm::load(header_fd, data_fd);
+    Segment header_segment;
+    Segment data_segment;
+    RingBuffer rb;
+    std::tie(header_segment, data_segment, rb) = shm::load(header_fd, data_fd);
    Consumer cons{rb};
 
    int i = 0;
@@ -84,16 +94,17 @@ TEST(ShmRingBuffer, SingleProducer_SingleConsumer) {
 
   if (pid == 0) {
     // child, the producer
-    // Make a scope so shared_ptr's are released even on exit(0).
+    // Make a scope so segments are destroyed even on exit(0).
     {
-      int header_fd;
-      int data_fd;
-      std::shared_ptr<RingBuffer> rb;
-      std::tie(header_fd, data_fd, rb) = shm::create(1024);
+      Segment header_segment;
+      Segment data_segment;
+      RingBuffer rb;
+      std::tie(header_segment, data_segment, rb) = shm::create(1024);
       Producer prod{rb};
 
       {
-        auto err = sendFdsToSocket(sock_fds[0], header_fd, data_fd);
+        auto err = sendFdsToSocket(
+            sock_fds[0], header_segment.getFd(), data_segment.getFd());
         if (err) {
           TP_THROW_ASSERT() << err.what();
         }
@@ -134,7 +145,10 @@ TEST(ShmRingBuffer, SingleProducer_SingleConsumer) {
      TP_THROW_ASSERT() << err.what();
    }
  }
-  auto rb = shm::load(header_fd, data_fd);
+  Segment header_segment;
+  Segment data_segment;
+  RingBuffer rb;
+  std::tie(header_segment, data_segment, rb) = shm::load(header_fd, data_fd);
  Consumer cons{rb};
 
  int i = 0;

tensorpipe/test/util/shm/segment_test.cc (+23 -18)

@@ -29,24 +29,29 @@ TEST(Segment, SameProducerConsumer_Scalar) {
   sched_setaffinity(0, sizeof(cpu_set_t), &cpuset);
 
   // This must stay alive for the file descriptor to remain open.
-  std::shared_ptr<Segment> producer_segment;
+  int fd = -1;
   {
     // Producer part.
-    std::shared_ptr<int> my_int_ptr;
-    std::tie(my_int_ptr, producer_segment) =
+    Segment segment;
+    int* my_int_ptr;
+    std::tie(segment, my_int_ptr) =
        Segment::create<int>(true, PageType::Default);
    int& my_int = *my_int_ptr;
    my_int = 1000;
+
+    // Duplicate the file descriptor so that the shared memory remains alive
+    // when the original fd is closed by the segment's destructor.
+    fd = ::dup(segment.getFd());
  }
 
  {
    // Consumer part.
    // Map file again (to a different address) and consume it.
-    std::shared_ptr<int> my_int_ptr;
-    std::shared_ptr<Segment> segment;
-    std::tie(my_int_ptr, segment) =
-        Segment::load<int>(producer_segment->getFd(), true, PageType::Default);
-    EXPECT_EQ(segment->getSize(), sizeof(int));
+    Segment segment;
+    int* my_int_ptr;
+    std::tie(segment, my_int_ptr) =
+        Segment::load<int>(fd, true, PageType::Default);
+    EXPECT_EQ(segment.getSize(), sizeof(int));
    EXPECT_EQ(*my_int_ptr, 1000);
  }
 };
@@ -78,17 +83,17 @@ TEST(SegmentManager, SingleProducer_SingleConsumer_Array) {
  {
    // use huge pages in creation and not in loading. This should only affects
    // TLB overhead.
-    std::shared_ptr<float> my_floats;
-    std::shared_ptr<Segment> segment;
-    std::tie(my_floats, segment) =
+    Segment segment;
+    float* my_floats;
+    std::tie(segment, my_floats) =
        Segment::create<float[]>(num_floats, true, PageType::HugeTLB_2MB);
 
    for (int i = 0; i < num_floats; ++i) {
-      my_floats.get()[i] = i;
+      my_floats[i] = i;
    }
 
    {
-      auto err = sendFdsToSocket(sock_fds[0], segment->getFd());
+      auto err = sendFdsToSocket(sock_fds[0], segment.getFd());
      if (err) {
        TP_THROW_ASSERT() << err.what();
      }
@@ -112,13 +117,13 @@ TEST(SegmentManager, SingleProducer_SingleConsumer_Array) {
      TP_THROW_ASSERT() << err.what();
    }
  }
-  std::shared_ptr<float> my_floats;
-  std::shared_ptr<Segment> segment;
-  std::tie(my_floats, segment) =
+  Segment segment;
+  float* my_floats;
+  std::tie(segment, my_floats) =
      Segment::load<float[]>(segment_fd, false, PageType::Default);
-  EXPECT_EQ(num_floats * sizeof(float), segment->getSize());
+  EXPECT_EQ(num_floats * sizeof(float), segment.getSize());
  for (int i = 0; i < num_floats; ++i) {
-    EXPECT_EQ(my_floats.get()[i], i);
+    EXPECT_EQ(my_floats[i], i);
  }
  {
    uint64_t c = 1;

tensorpipe/transport/shm/connection.cc (+16 -13)

@@ -361,13 +361,15 @@ class Connection::Impl : public std::enable_shared_from_this<Connection::Impl>,
   ClosingReceiver closingReceiver_;
 
   // Inbox.
-  int inboxHeaderFd_;
-  int inboxDataFd_;
-  optional<util::ringbuffer::Consumer> inbox_;
+  util::shm::Segment inboxHeaderSegment_;
+  util::shm::Segment inboxDataSegment_;
+  util::ringbuffer::RingBuffer inboxRb_;
   optional<Reactor::TToken> inboxReactorToken_;
 
   // Outbox.
-  optional<util::ringbuffer::Producer> outbox_;
+  util::shm::Segment outboxHeaderSegment_;
+  util::shm::Segment outboxDataSegment_;
+  util::ringbuffer::RingBuffer outboxRb_;
   optional<Reactor::TToken> outboxReactorToken_;
 
   // Peer trigger/tokens.
@@ -514,10 +516,8 @@ void Connection::Impl::initFromLoop() {
   }
 
   // Create ringbuffer for inbox.
-  std::shared_ptr<util::ringbuffer::RingBuffer> inboxRingBuffer;
-  std::tie(inboxHeaderFd_, inboxDataFd_, inboxRingBuffer) =
+  std::tie(inboxHeaderSegment_, inboxDataSegment_, inboxRb_) =
      util::ringbuffer::shm::create(kBufferSize);
-  inbox_.emplace(std::move(inboxRingBuffer));
 
  // Register method to be called when our peer writes to our inbox.
  inboxReactorToken_ = context_->addReaction(runIfAlive(*this, [](Impl& impl) {
@@ -919,8 +919,9 @@ void Connection::Impl::handleEventInFromLoop() {
  }
 
  // Load ringbuffer for outbox.
-  outbox_.emplace(util::ringbuffer::shm::load(
-      outboxHeaderFd.release(), outboxDataFd.release()));
+  std::tie(outboxHeaderSegment_, outboxDataSegment_, outboxRb_) =
+      util::ringbuffer::shm::load(
+          outboxHeaderFd.release(), outboxDataFd.release());
 
  // Initialize remote reactor trigger.
  peerReactorTrigger_.emplace(
@@ -963,8 +964,8 @@ void Connection::Impl::handleEventOutFromLoop() {
      outboxReactorToken_.value(),
      reactorHeaderFd,
      reactorDataFd,
-      inboxHeaderFd_,
-      inboxDataFd_);
+      inboxHeaderSegment_.getFd(),
+      inboxDataSegment_.getFd());
  if (err) {
    setError_(std::move(err));
    return;
@@ -988,10 +989,11 @@ void Connection::Impl::processReadOperationsFromLoop() {
    return;
  }
  // Serve read operations
+  util::ringbuffer::Consumer inboxConsumer(inboxRb_);
  while (!readOperations_.empty()) {
    auto readOperation = std::move(readOperations_.front());
    readOperations_.pop_front();
-    if (readOperation.handleRead(*inbox_)) {
+    if (readOperation.handleRead(inboxConsumer)) {
      peerReactorTrigger_->run(peerOutboxReactorToken_.value());
    }
    if (!readOperation.completed()) {
@@ -1008,10 +1010,11 @@ void Connection::Impl::processWriteOperationsFromLoop() {
    return;
  }
 
+  util::ringbuffer::Producer outboxProducer(outboxRb_);
  while (!writeOperations_.empty()) {
    auto writeOperation = std::move(writeOperations_.front());
    writeOperations_.pop_front();
-    if (writeOperation.handleWrite(*outbox_)) {
+    if (writeOperation.handleWrite(outboxProducer)) {
      peerReactorTrigger_->run(peerInboxReactorToken_.value());
    }
    if (!writeOperation.completed()) {

tensorpipe/transport/shm/reactor.cc (+13 -22)

@@ -43,14 +43,8 @@ void writeToken(util::ringbuffer::Producer& producer, Reactor::TToken token) {
 } // namespace
 
 Reactor::Reactor() {
-  int headerFd;
-  int dataFd;
-  std::shared_ptr<util::ringbuffer::RingBuffer> rb;
-  std::tie(headerFd, dataFd, rb) = util::ringbuffer::shm::create(kSize);
-  headerFd_ = Fd(headerFd);
-  dataFd_ = Fd(dataFd);
-  consumer_.emplace(rb);
-  producer_.emplace(rb);
+  std::tie(headerSegment_, dataSegment_, rb_) =
+      util::ringbuffer::shm::create(kSize);
   thread_ = std::thread(&Reactor::run, this);
 }
 
@@ -106,23 +100,19 @@ void Reactor::remove(TToken token) {
   functionCount_--;
 }
 
-void Reactor::trigger(TToken token) {
-  std::unique_lock<std::mutex> lock(mutex_);
-  writeToken(producer_.value(), token);
-}
-
 std::tuple<int, int> Reactor::fds() const {
-  return std::make_tuple(headerFd_.fd(), dataFd_.fd());
+  return std::make_tuple(headerSegment_.getFd(), dataSegment_.getFd());
 }
 
 void Reactor::run() {
   setThreadName("TP_SHM_reactor");
 
+  util::ringbuffer::Consumer reactorConsumer(rb_);
   // Stop when another thread has asked the reactor the close and when
   // all functions have been removed.
   while (!closed_ || functionCount_ > 0) {
     uint32_t token;
-    auto ret = consumer_->read(&token, sizeof(token));
+    auto ret = reactorConsumer.read(&token, sizeof(token));
     if (ret == -ENODATA) {
       if (deferredFunctionCount_ > 0) {
         decltype(deferredFunctionList_) fns;
@@ -179,15 +169,16 @@ void Reactor::run() {
   }
 }
 
-Reactor::Trigger::Trigger(Fd&& headerFd, Fd&& dataFd)
-    : producer_(util::ringbuffer::shm::load(
-          // The header and data segment objects take over ownership
-          // of file descriptors. Release them to avoid double close.
-          headerFd.release(),
-          dataFd.release())) {}
+Reactor::Trigger::Trigger(Fd&& headerFd, Fd&& dataFd) {
+  // The header and data segment objects take over ownership
+  // of file descriptors. Release them to avoid double close.
+  std::tie(headerSegment_, dataSegment_, rb_) =
+      util::ringbuffer::shm::load(headerFd.release(), dataFd.release());
+}
 
 void Reactor::Trigger::run(TToken token) {
-  writeToken(producer_, token);
+  util::ringbuffer::Producer producer(rb_);
+  writeToken(producer, token);
 }
 
 void Reactor::deferToLoop(TDeferredFunction fn) {

tensorpipe/transport/shm/reactor.h (+7 -8)

@@ -22,6 +22,7 @@
 #include <tensorpipe/transport/shm/fd.h>
 #include <tensorpipe/util/ringbuffer/consumer.h>
 #include <tensorpipe/util/ringbuffer/producer.h>
+#include <tensorpipe/util/shm/segment.h>
 
 namespace tensorpipe {
 namespace transport {
@@ -87,9 +88,6 @@ class Reactor final {
   // Removes function associated with token from reactor.
   void remove(TToken token);
 
-  // Trigger reactor with specified token.
-  void trigger(TToken token);
-
   // Returns the file descriptors for the underlying ring buffer.
   std::tuple<int, int> fds() const;
 
@@ -110,10 +108,9 @@ class Reactor final {
   ~Reactor();
 
  private:
-  Fd headerFd_;
-  Fd dataFd_;
-  optional<util::ringbuffer::Consumer> consumer_;
-  optional<util::ringbuffer::Producer> producer_;
+  util::shm::Segment headerSegment_;
+  util::shm::Segment dataSegment_;
+  util::ringbuffer::RingBuffer rb_;
 
   std::mutex mutex_;
   std::thread thread_;
@@ -162,7 +159,9 @@ class Reactor final {
     void run(TToken token);
 
    private:
-    util::ringbuffer::Producer producer_;
+    util::shm::Segment headerSegment_;
+    util::shm::Segment dataSegment_;
+    util::ringbuffer::RingBuffer rb_;
   };
 };
 

tensorpipe/util/ringbuffer/consumer.h (+14 -5)

@@ -19,15 +19,21 @@ namespace ringbuffer {
 ///
 /// Provides methods to read data from a ringbuffer.
 ///
-class Consumer : public RingBufferWrapper {
+class Consumer {
  public:
-  // Use base class constructor.
-  using RingBufferWrapper::RingBufferWrapper;
+  Consumer() = delete;
+
+  Consumer(RingBuffer& rb) : header_{rb.getHeader()}, data_{rb.getData()} {
+    TP_THROW_IF_NULLPTR(data_);
+  }
 
   Consumer(const Consumer&) = delete;
   Consumer(Consumer&&) = delete;
 
-  virtual ~Consumer() noexcept {
+  Consumer& operator=(const Consumer&) = delete;
+  Consumer& operator=(Consumer&&) = delete;
+
+  ~Consumer() noexcept {
     TP_THROW_ASSERT_IF(inTx());
   }
 
@@ -192,7 +198,10 @@ class Consumer : public RingBufferWrapper {
     return size;
   }
 
- protected:
+ private:
+  RingBufferHeader& header_;
+  const uint8_t* const data_;
+  unsigned tx_size_ = 0;
   bool inTx_{false};
 };
 

tensorpipe/util/ringbuffer/producer.h (+14 -5)

@@ -19,15 +19,21 @@ namespace ringbuffer {
 ///
 /// Provides methods to write data into a ringbuffer.
 ///
-class Producer : public RingBufferWrapper {
+class Producer {
  public:
-  // Use base class constructor.
-  using RingBufferWrapper::RingBufferWrapper;
+  Producer() = delete;
+
+  Producer(RingBuffer& rb) : header_{rb.getHeader()}, data_{rb.getData()} {
+    TP_THROW_IF_NULLPTR(data_);
+  }
 
   Producer(const Producer&) = delete;
   Producer(Producer&&) = delete;
 
-  virtual ~Producer() noexcept {
+  Producer& operator=(const Producer&) = delete;
+  Producer& operator=(Producer&&) = delete;
+
+  ~Producer() noexcept {
     TP_THROW_ASSERT_IF(inTx());
   }
 
@@ -196,7 +202,10 @@ class Producer : public RingBufferWrapper {
     return size;
   }
 
- protected:
+ private:
+  RingBufferHeader& header_;
+  uint8_t* const data_;
+  unsigned tx_size_ = 0;
   bool inTx_{false};
 };
 

tensorpipe/util/ringbuffer/ringbuffer.h (+7 -45)

@@ -133,13 +133,10 @@ class RingBufferHeader {
 ///
 class RingBuffer final {
  public:
-  RingBuffer(const RingBuffer&) = delete;
-  RingBuffer(RingBuffer&&) = delete;
+  RingBuffer() = default;
 
-  RingBuffer(
-      std::shared_ptr<RingBufferHeader> header,
-      std::shared_ptr<uint8_t> data)
-      : header_{std::move(header)}, data_{std::move(data)} {
+  RingBuffer(RingBufferHeader* header, uint8_t* data)
+      : header_(header), data_(data) {
     TP_THROW_IF_NULLPTR(header_) << "Header cannot be nullptr";
     TP_THROW_IF_NULLPTR(data_) << "Data cannot be nullptr";
   }
@@ -153,51 +150,16 @@ class RingBuffer final {
   }
 
   const uint8_t* getData() const {
-    return data_.get();
+    return data_;
   }
 
   uint8_t* getData() {
-    return data_.get();
+    return data_;
   }
 
  protected:
-  std::shared_ptr<RingBufferHeader> header_;
-
-  // Note: this is a std::shared_ptr<uint8_t[]> semantically.
-  // A shared_ptr with array type is supported in C++17 and higher.
-  std::shared_ptr<uint8_t> data_;
-};
-
-///
-/// Ringbuffer wrapper
-///
-class RingBufferWrapper {
- public:
-  RingBufferWrapper(const RingBufferWrapper&) = delete;
-
-  RingBufferWrapper(std::shared_ptr<RingBuffer> rb)
-      : rb_{rb}, header_{rb->getHeader()}, data_{rb->getData()} {
-    TP_THROW_IF_NULLPTR(rb);
-    TP_THROW_IF_NULLPTR(data_);
-  }
-
-  auto& getHeader() {
-    return header_;
-  }
-
-  const auto& getHeader() const {
-    return header_;
-  }
-
-  auto getRingBuffer() {
-    return rb_;
-  }
-
- protected:
-  std::shared_ptr<RingBuffer> rb_;
-  RingBufferHeader& header_;
-  uint8_t* const data_;
-  unsigned tx_size_ = 0;
+  RingBufferHeader* header_ = nullptr;
+  uint8_t* data_ = nullptr;
 };
 
 } // namespace ringbuffer

tensorpipe/util/ringbuffer/shm.cc (+30 -32)

@@ -13,57 +13,55 @@ namespace util {
 namespace ringbuffer {
 namespace shm {
 
-std::tuple<int, int, std::shared_ptr<RingBuffer>> create(
+std::tuple<util::shm::Segment, util::shm::Segment, RingBuffer> create(
     size_t min_rb_byte_size,
-    optional<tensorpipe::util::shm::PageType> data_page_type,
+    optional<util::shm::PageType> data_page_type,
     bool perm_write) {
-  std::shared_ptr<RingBufferHeader> header;
-  std::shared_ptr<tensorpipe::util::shm::Segment> header_segment;
-  std::tie(header, header_segment) =
-      tensorpipe::util::shm::Segment::create<RingBufferHeader>(
-          perm_write,
-          tensorpipe::util::shm::PageType::Default,
-          min_rb_byte_size);
+  util::shm::Segment header_segment;
+  RingBufferHeader* header;
+  std::tie(header_segment, header) =
+      util::shm::Segment::create<RingBufferHeader>(
+          perm_write, util::shm::PageType::Default, min_rb_byte_size);
 
-  std::shared_ptr<uint8_t> data;
-  std::shared_ptr<tensorpipe::util::shm::Segment> data_segment;
-  std::tie(data, data_segment) =
-      tensorpipe::util::shm::Segment::create<uint8_t[]>(
-          header->kDataPoolByteSize, perm_write, data_page_type);
+  util::shm::Segment data_segment;
+  uint8_t* data;
+  std::tie(data_segment, data) = util::shm::Segment::create<uint8_t[]>(
+      header->kDataPoolByteSize, perm_write, data_page_type);
 
   // Note: cannot use implicit construction from initializer list on GCC 5.5:
   // "converting to XYZ from initializer list would use explicit constructor".
   return std::make_tuple(
-      header_segment->getFd(),
-      data_segment->getFd(),
-      std::make_shared<RingBuffer>(std::move(header), std::move(data)));
+      std::move(header_segment),
+      std::move(data_segment),
+      RingBuffer(header, data));
 }
 
-std::shared_ptr<RingBuffer> load(
+std::tuple<util::shm::Segment, util::shm::Segment, RingBuffer> load(
     int header_fd,
     int data_fd,
-    optional<tensorpipe::util::shm::PageType> data_page_type,
+    optional<util::shm::PageType> data_page_type,
     bool perm_write) {
-  std::shared_ptr<RingBufferHeader> header;
-  std::shared_ptr<tensorpipe::util::shm::Segment> header_segment;
-  std::tie(header, header_segment) =
-      tensorpipe::util::shm::Segment::load<RingBufferHeader>(
-          header_fd, perm_write, tensorpipe::util::shm::PageType::Default);
+  util::shm::Segment header_segment;
+  RingBufferHeader* header;
+  std::tie(header_segment, header) = util::shm::Segment::load<RingBufferHeader>(
+      header_fd, perm_write, util::shm::PageType::Default);
   constexpr auto kHeaderSize = sizeof(RingBufferHeader);
-  if (unlikely(kHeaderSize != header_segment->getSize())) {
+  if (unlikely(kHeaderSize != header_segment.getSize())) {
     TP_THROW_SYSTEM(EPERM) << "Header segment of unexpected size";
   }
 
-  std::shared_ptr<uint8_t> data;
-  std::shared_ptr<tensorpipe::util::shm::Segment> data_segment;
-  std::tie(data, data_segment) =
-      tensorpipe::util::shm::Segment::load<uint8_t[]>(
-          data_fd, perm_write, data_page_type);
-  if (unlikely(header->kDataPoolByteSize != data_segment->getSize())) {
+  util::shm::Segment data_segment;
+  uint8_t* data;
+  std::tie(data_segment, data) =
+      util::shm::Segment::load<uint8_t[]>(data_fd, perm_write, data_page_type);
+  if (unlikely(header->kDataPoolByteSize != data_segment.getSize())) {
     TP_THROW_SYSTEM(EPERM) << "Data segment of unexpected size";
   }
 
-  return std::make_shared<RingBuffer>(std::move(header), std::move(data));
+  return std::make_tuple(
+      std::move(header_segment),
+      std::move(data_segment),
+      RingBuffer(header, data));
 }
 
 } // namespace shm

tensorpipe/util/ringbuffer/shm.h (+5 -5)

@@ -18,7 +18,7 @@ namespace shm {
 
 /// Creates ringbuffer on shared memory.
 ///
-/// RingBuffer's data can have any <tensorpipe::util::shm::PageType>
+/// RingBuffer's data can have any <util::shm::PageType>
 /// (e.g. 4KB or a HugeTLB Page of 2MB or 1GB). If <data_page_type> is not
 /// provided, then choose the largest page that would result in
 /// close to full occupancy.
@@ -29,15 +29,15 @@ namespace shm {
 /// <min_rb_byte_size> is the minimum size of the data section
 /// of a RingBuffer (or each CPU's RingBuffer).
 ///
-std::tuple<int, int, std::shared_ptr<RingBuffer>> create(
+std::tuple<util::shm::Segment, util::shm::Segment, RingBuffer> create(
     size_t min_rb_byte_size,
-    optional<tensorpipe::util::shm::PageType> data_page_type = nullopt,
+    optional<util::shm::PageType> data_page_type = nullopt,
     bool perm_write = true);
 
-std::shared_ptr<RingBuffer> load(
+std::tuple<util::shm::Segment, util::shm::Segment, RingBuffer> load(
     int header_fd,
     int data_fd,
-    optional<tensorpipe::util::shm::PageType> data_page_type = nullopt,
+    optional<util::shm::PageType> data_page_type = nullopt,
     bool perm_write = true);
 
 } // namespace shm

tensorpipe/util/shm/segment.cc (+24 -9)

@@ -82,6 +82,21 @@ update to obtain the latest correctness checks."
 // Mention static constexpr char to export the symbol.
 constexpr char Segment::kBasePath[];
 
+Segment::Segment(Segment&& other) noexcept {
+  std::swap(page_type_, other.page_type_);
+  std::swap(fd_, other.fd_);
+  std::swap(byte_size_, other.byte_size_);
+  std::swap(base_ptr_, other.base_ptr_);
+}
+
+Segment& Segment::operator=(Segment&& other) noexcept {
+  std::swap(page_type_, other.page_type_);
+  std::swap(fd_, other.fd_);
+  std::swap(byte_size_, other.byte_size_);
+  std::swap(base_ptr_, other.base_ptr_);
+  return *this;
+}
+
 Segment::Segment(
     size_t byte_size,
     bool perm_write,
@@ -118,17 +133,17 @@ void Segment::mmap(bool perm_write, optional<PageType> page_type) {
 }
 
 Segment::~Segment() {
-  int ret = munmap(base_ptr_, byte_size_);
-  if (ret == -1) {
-    TP_LOG_ERROR() << "Error while munmapping shared memory segment. Error: "
-                   << toErrorCode(errno).message();
+  if (base_ptr_ != nullptr) {
+    int ret = munmap(base_ptr_, byte_size_);
+    if (ret == -1) {
+      TP_LOG_ERROR() << "Error while munmapping shared memory segment. Error: "
+                     << toErrorCode(errno).message();
+    }
   }
-  if (0 > fd_) {
-    TP_LOG_ERROR() << "Attempt to destroy segment with negative file "
-                   << "descriptor";
-    return;
+  if (fd_ >= 0) {
+    int ret = ::close(fd_);
+    TP_THROW_SYSTEM_IF(ret != 0, errno);
   }
-  ::close(fd_);
 }
 
 } // namespace shm

tensorpipe/util/shm/segment.h (+40 -43)

@@ -61,20 +61,25 @@ class Segment {
   // Default base path for all segments created.
   static constexpr char kBasePath[] = "/dev/shm";
 
+  Segment() = default;
+
   Segment(const Segment&) = delete;
-  Segment(Segment&&) = delete;
+  Segment(Segment&&) noexcept;
+
+  Segment& operator=(const Segment& other) = delete;
+  Segment& operator=(Segment&& other) noexcept;
 
   Segment(size_t byte_size, bool perm_write, optional<PageType> page_type);
 
   Segment(int fd, bool perm_write, optional<PageType> page_type);
 
-  /// Create read and size shared memory to contain an object of class T.
+  /// Allocate shared memory to contain an object of type T and construct it.
   ///
-  /// The created object's shared_ptr will own the lifetime of the
-  /// Segment and will call Segment destructor.
-  /// Caller can use the shared_ptr to the underlying Segment.
+  /// The Segment object owns the memory and frees it when destructed.
+  /// The raw pointer to the object provides a view into the Segment but doesn't
+  /// own it and may thus become invalid if the Segment isn't kept alive.
   template <class T, class... Args>
-  static std::pair<std::shared_ptr<T>, std::shared_ptr<Segment>> create(
+  static std::pair<Segment, T*> create(
      bool perm_write,
      optional<PageType> page_type,
      Args&&... args) {
@@ -84,25 +89,24 @@ class Segment {
    static_assert(std::is_trivially_copyable<T>::value, "!");
 
    const auto byte_size = sizeof(T);
-    auto segment = std::make_shared<Segment>(byte_size, perm_write, page_type);
-    TP_DCHECK_EQ(segment->getSize(), byte_size);
+    Segment segment(byte_size, perm_write, page_type);
+    TP_DCHECK_EQ(segment.getSize(), byte_size);
 
    // Initialize in place. Forward T's constructor arguments.
-    T* ptr = new (segment->getPtr()) T(std::forward<Args>(args)...);
-    if (ptr != segment->getPtr()) {
+    T* ptr = new (segment.getPtr()) T(std::forward<Args>(args)...);
+    if (ptr != segment.getPtr()) {
      TP_THROW_SYSTEM(EPERM)
-          << "new's address cannot be different from segment->getPtr() "
+          << "new's address cannot be different from segment.getPtr() "
          << " address. Some aligment assumption was incorrect";
    }
 
-    return {std::shared_ptr<T>(segment, ptr), segment};
+    return {std::move(segment), ptr};
  }
 
  /// One-dimensional array version of create<T, ...Args>.
-  /// Caller can use the shared_ptr to the underlying Segment.
  // XXX: Fuse all versions of create.
  template <class T, typename TScalar = typename std::remove_extent<T>::type>
-  static std::pair<std::shared_ptr<TScalar>, std::shared_ptr<Segment>> create(
+  static std::pair<Segment, TScalar*> create(
      size_t num_elements,
      bool perm_write,
      optional<PageType> page_type) {
@@ -125,55 +129,48 @@ class Segment {
    static_assert(std::is_same<TScalar[], T>::value, "Type mismatch");
 
    size_t byte_size = sizeof(TScalar) * num_elements;
-    auto segment = std::make_shared<Segment>(byte_size, perm_write, page_type);
-    TP_DCHECK_EQ(segment->getSize(), byte_size);
+    Segment segment(byte_size, perm_write, page_type);
+    TP_DCHECK_EQ(segment.getSize(), byte_size);
 
    // Initialize in place.
-    TScalar* ptr = new (segment->getPtr()) TScalar[num_elements]();
-    if (ptr != segment->getPtr()) {
+    TScalar* ptr = new (segment.getPtr()) TScalar[num_elements]();
+    if (ptr != segment.getPtr()) {
      TP_THROW_SYSTEM(EPERM)
-          << "new's address cannot be different from segment->getPtr() "
+          << "new's address cannot be different from segment.getPtr() "
         << " address. Some aligment assumption was incorrect";
    }
 
-    return {std::shared_ptr<TScalar>(segment, ptr), segment};
+    return {std::move(segment), ptr};
  }
 
-  /// Load an already created shared memory Segment that holds an
-  /// object of type T, where T is an array type.
-  ///
-  /// Lifecycle of shared_ptr and Segment's reference_wrapper is
-  /// identical to create<>().
+  /// Load an existing shared memory region that already holds an object of type
+  /// T, where T is an array type.
  template <
      class T,
      typename TScalar = typename std::remove_extent<T>::type,
      std::enable_if_t<std::is_array<T>::value, int> = 0>
-  static std::pair<std::shared_ptr<TScalar>, std::shared_ptr<Segment>> load(
+  static std::pair<Segment, TScalar*> load(
      int fd,
      bool perm_write,
      optional<PageType> page_type) {
-    auto segment = std::make_shared<Segment>(fd, perm_write, page_type);
-    const size_t size = segment->getSize();
+    Segment segment(fd, perm_write, page_type);
    static_assert(
        std::rank<T>::value == 1,
        "Currently only rank one arrays are supported");
    static_assert(std::is_trivially_copyable<TScalar>::value, "!");
-    auto ptr = static_cast<TScalar*>(segment->getPtr());
-    return {std::shared_ptr<TScalar>(segment, ptr), segment};
+    auto ptr = static_cast<TScalar*>(segment.getPtr());
+    return {std::move(segment), ptr};
  }
 
-  /// Load an already created shared memory Segment that holds an
-  /// object of type T, where T is NOT an array type.
-  ///
-  /// Lifecycle of shared_ptr and Segment's reference_wrapper is
-  /// identical to create<>().
+  /// Load an existing shared memory region that already holds an object of type
+  /// T, where T is NOT an array type.
  template <class T, std::enable_if_t<!std::is_array<T>::value, int> = 0>
-  static std::pair<std::shared_ptr<T>, std::shared_ptr<Segment>> load(
+  static std::pair<Segment, T*> load(
      int fd,
      bool perm_write,
      optional<PageType> page_type) {
-    auto segment = std::make_shared<Segment>(fd, perm_write, page_type);
-    const size_t size = segment->getSize();
+    Segment segment(fd, perm_write, page_type);
+    const size_t size = segment.getSize();
    // XXX: Do some checking other than the size that we are loading
    // the right type.
    if (size != sizeof(T)) {
@@ -184,8 +181,8 @@ class Segment {
          << "consider linking segment after it has been fully initialized.";
    }
    static_assert(std::is_trivially_copyable<T>::value, "!");
-    auto ptr = static_cast<T*>(segment->getPtr());
-    return {std::shared_ptr<T>(segment, ptr), segment};
+    auto ptr = static_cast<T*>(segment.getPtr());
+    return {std::move(segment), ptr};
  }
 
  const int getFd() const {
@@ -212,16 +209,16 @@ class Segment {
 
 protected:
  // The page used to mmap the segment.
-  PageType page_type_;
+  PageType page_type_ = PageType::Default;
 
  // The file descriptor of the shared memory file.
  int fd_ = -1;
 
  // Byte size of shared memory segment.
-  size_t byte_size_;
+  size_t byte_size_ = 0;
 
  // Base pointer of mmmap'ed shared memory segment.
-  void* base_ptr_;
+  void* base_ptr_ = nullptr;
 
  void mmap(bool perm_write, optional<PageType> page_type);
 };
