Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 4f6e241

Browse files
lwfacebook-github-bot
authored andcommittedSep 4, 2020
Replace protobuf with libnop in connections
Summary: This is the big change, where we swap out protobuf for libnop in the biggest usecases, namely the pipe and the channels. Luckily all these changes are identical so although this consists in many lines the complexity of this change is limited. What we do is: - change the interface of connections to provide a specialization of read/write for NopHolders, rather than for protobuf Messages. - convert all the proto definitions to C++ structs, defined in new nop_types.h files. - go through all usages of the protobufs and all callsites of read/write and update them. This change leaves behind some proto defs which are still being used by some unit tests, as these tests will be ported in later commits and those protos will be removed then. It also doesn't reimplement yet the specialized codepath that the SHM has for protobuf messages: it is left there, albeit "cut off", which means that SHM will fall back to the generic implementation and be slower. This will be temporary as in the next commit we'll port that specialized code and re-attach it. Reviewed By: heiner Differential Revision: D22763737 fbshipit-source-id: d44be13e827822662a0b0395aff05ece478dbc51
1 parent ab9fb4b commit 4f6e241

File tree

14 files changed

+464
-360
lines changed

14 files changed

+464
-360
lines changed
 

‎tensorpipe/CMakeLists.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ endif()
6060

6161
target_sources(tensorpipe PRIVATE
6262
channel/mpt/channel.cc
63-
channel/mpt/context.cc
64-
proto/channel/mpt.proto)
63+
channel/mpt/context.cc)
6564

6665

6766
## Transports

‎tensorpipe/channel/mpt/channel.cc

+42-35
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313

1414
#include <tensorpipe/channel/error.h>
1515
#include <tensorpipe/channel/helpers.h>
16+
#include <tensorpipe/channel/mpt/nop_types.h>
1617
#include <tensorpipe/common/callback.h>
1718
#include <tensorpipe/common/defs.h>
1819
#include <tensorpipe/common/error_macros.h>
19-
#include <tensorpipe/proto/channel/mpt.pb.h>
2020
#include <tensorpipe/transport/context.h>
2121
#include <tensorpipe/transport/error.h>
2222
#include <tensorpipe/transport/listener.h>
@@ -103,7 +103,7 @@ class Channel::Impl : public std::enable_shared_from_this<Channel::Impl> {
103103
void closeFromLoop_();
104104

105105
// Called when client reads the server's hello on backbone connection
106-
void onClientReadHelloOnConnection_(const proto::Packet& pbPacketIn);
106+
void onClientReadHelloOnConnection_(const Packet& nopPacketIn);
107107

108108
// Called when server accepts new client connection for lane
109109
void onServerAcceptOfLane_(
@@ -209,23 +209,27 @@ void Channel::Impl::initFromLoop_() {
209209
TP_DCHECK_EQ(state_, UNINITIALIZED);
210210
if (endpoint_ == Endpoint::kConnect) {
211211
state_ = CLIENT_READING_HELLO;
212-
auto packet = std::make_shared<proto::Packet>();
213-
TP_VLOG(6) << "Channel " << id_ << " reading proto (server hello)";
214-
connection_->read(*packet, lazyCallbackWrapper_([packet](Impl& impl) {
215-
TP_VLOG(6) << "Channel " << impl.id_
216-
<< " done reading proto (server hello)";
217-
impl.onClientReadHelloOnConnection_(*packet);
218-
}));
212+
auto nopHolderIn = std::make_shared<NopHolder<Packet>>();
213+
TP_VLOG(6) << "Channel " << id_ << " reading nop object (server hello)";
214+
connection_->read(
215+
*nopHolderIn, lazyCallbackWrapper_([nopHolderIn](Impl& impl) {
216+
TP_VLOG(6) << "Channel " << impl.id_
217+
<< " done reading nop object (server hello)";
218+
impl.onClientReadHelloOnConnection_(nopHolderIn->getObject());
219+
}));
219220
} else if (endpoint_ == Endpoint::kListen) {
220221
state_ = SERVER_ACCEPTING_LANES;
221222
const std::vector<std::string>& addresses = context_->addresses();
222223
TP_DCHECK_EQ(addresses.size(), numLanes_);
223-
auto packet = std::make_shared<proto::Packet>();
224-
proto::ServerHello* pbServerHello = packet->mutable_server_hello();
224+
auto nopHolderOut = std::make_shared<NopHolder<Packet>>();
225+
Packet& nopPacket = nopHolderOut->getObject();
226+
nopPacket.Become(nopPacket.index_of<ServerHello>());
227+
ServerHello& nopServerHello = *nopPacket.get<ServerHello>();
225228
for (uint64_t laneIdx = 0; laneIdx < numLanes_; ++laneIdx) {
226-
proto::LaneAdvertisement* pbLaneAdvertisement =
227-
pbServerHello->add_lane_advertisements();
228-
pbLaneAdvertisement->set_address(addresses[laneIdx]);
229+
nopServerHello.laneAdvertisements.emplace_back();
230+
LaneAdvertisement& nopLaneAdvertisement =
231+
nopServerHello.laneAdvertisements.back();
232+
nopLaneAdvertisement.address = addresses[laneIdx];
229233
TP_VLOG(6) << "Channel " << id_ << " requesting connection (for lane "
230234
<< laneIdx << ")";
231235
uint64_t token = context_->registerConnectionRequest(
@@ -240,14 +244,15 @@ void Channel::Impl::initFromLoop_() {
240244
impl.onServerAcceptOfLane_(laneIdx, std::move(connection));
241245
}));
242246
laneRegistrationIds_.emplace(laneIdx, token);
243-
pbLaneAdvertisement->set_registration_id(token);
247+
nopLaneAdvertisement.registrationId = token;
244248
numLanesBeingAccepted_++;
245249
}
246-
TP_VLOG(6) << "Channel " << id_ << " writing proto (server hello)";
247-
connection_->write(*packet, lazyCallbackWrapper_([packet](Impl& impl) {
248-
TP_VLOG(6) << "Channel " << impl.id_
249-
<< " done writing proto (server hello)";
250-
}));
250+
TP_VLOG(6) << "Channel " << id_ << " writing nop object (server hello)";
251+
connection_->write(
252+
*nopHolderOut, lazyCallbackWrapper_([nopHolderOut](Impl& impl) {
253+
TP_VLOG(6) << "Channel " << impl.id_
254+
<< " done writing nop object (server hello)";
255+
}));
251256
} else {
252257
TP_THROW_ASSERT() << "unknown endpoint";
253258
}
@@ -406,29 +411,31 @@ void Channel::Impl::setIdFromLoop_(std::string id) {
406411
id_ = std::move(id);
407412
}
408413

409-
void Channel::Impl::onClientReadHelloOnConnection_(
410-
const proto::Packet& pbPacketIn) {
414+
void Channel::Impl::onClientReadHelloOnConnection_(const Packet& nopPacketIn) {
411415
TP_DCHECK(loop_.inLoop());
412416
TP_DCHECK_EQ(state_, CLIENT_READING_HELLO);
413-
TP_DCHECK_EQ(pbPacketIn.type_case(), proto::Packet::kServerHello);
417+
TP_DCHECK_EQ(nopPacketIn.index(), nopPacketIn.index_of<ServerHello>());
414418

415-
const proto::ServerHello& pbServerHello = pbPacketIn.server_hello();
416-
TP_DCHECK_EQ(pbServerHello.lane_advertisements().size(), numLanes_);
419+
const ServerHello& nopServerHello = *nopPacketIn.get<ServerHello>();
420+
TP_DCHECK_EQ(nopServerHello.laneAdvertisements.size(), numLanes_);
417421
lanes_.resize(numLanes_);
418422
for (uint64_t laneIdx = 0; laneIdx < numLanes_; ++laneIdx) {
419-
const proto::LaneAdvertisement& pbLaneAdvertisement =
420-
pbServerHello.lane_advertisements().Get(laneIdx);
423+
const LaneAdvertisement& nopLaneAdvertisement =
424+
nopServerHello.laneAdvertisements[laneIdx];
421425
std::shared_ptr<transport::Connection> lane =
422-
context_->connect(laneIdx, pbLaneAdvertisement.address());
423-
auto pbPacketOut = std::make_shared<proto::Packet>();
424-
proto::ClientHello* pbClientHello = pbPacketOut->mutable_client_hello();
425-
pbClientHello->set_registration_id(pbLaneAdvertisement.registration_id());
426-
TP_VLOG(6) << "Channel " << id_ << " writing proto (client hello) on lane "
427-
<< laneIdx;
426+
context_->connect(laneIdx, nopLaneAdvertisement.address);
427+
auto nopHolderOut = std::make_shared<NopHolder<Packet>>();
428+
Packet& nopPacket = nopHolderOut->getObject();
429+
nopPacket.Become(nopPacket.index_of<ClientHello>());
430+
ClientHello& nopClientHello = *nopPacket.get<ClientHello>();
431+
nopClientHello.registrationId = nopLaneAdvertisement.registrationId;
432+
TP_VLOG(6) << "Channel " << id_
433+
<< " writing nop object (client hello) on lane " << laneIdx;
428434
lane->write(
429-
*pbPacketOut, lazyCallbackWrapper_([laneIdx, pbPacketOut](Impl& impl) {
435+
*nopHolderOut,
436+
lazyCallbackWrapper_([laneIdx, nopHolderOut](Impl& impl) {
430437
TP_VLOG(6) << "Channel " << impl.id_
431-
<< " done writing proto (client hello) on lane "
438+
<< " done writing nop object (client hello) on lane "
432439
<< laneIdx;
433440
}));
434441
lanes_[laneIdx] = std::move(lane);

‎tensorpipe/channel/mpt/context.cc

+24-26
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
#include <tensorpipe/channel/error.h>
1818
#include <tensorpipe/channel/helpers.h>
1919
#include <tensorpipe/channel/mpt/channel.h>
20+
#include <tensorpipe/channel/mpt/nop_types.h>
2021
#include <tensorpipe/channel/registry.h>
2122
#include <tensorpipe/common/callback.h>
2223
#include <tensorpipe/common/defs.h>
2324
#include <tensorpipe/common/error_macros.h>
24-
#include <tensorpipe/proto/channel/mpt.pb.h>
2525
#include <tensorpipe/transport/context.h>
2626
#include <tensorpipe/transport/error.h>
2727
#include <tensorpipe/transport/listener.h>
@@ -96,7 +96,7 @@ class Context::Impl : public Context::PrivateIface,
9696
void onAcceptOfLane_(std::shared_ptr<transport::Connection>);
9797
void onReadClientHelloOnLane_(
9898
std::shared_ptr<transport::Connection>,
99-
const proto::Packet&);
99+
const Packet&);
100100

101101
void setError_(Error error);
102102

@@ -157,8 +157,8 @@ Context::Impl::Impl(
157157
TP_THROW_ASSERT_IF(contexts_.size() != listeners_.size());
158158
numLanes_ = contexts_.size();
159159
// FIXME Escape the contexts' domain descriptors in case they contain a colon?
160-
// Or put them all in a protobuf, that'll do the escaping for us.
161-
// But is it okay to compare protobufs by equality bitwise?
160+
// Or put them all in a nop object, that'll do the escaping for us.
161+
// But is it okay to compare nop objects by equality bitwise?
162162
std::ostringstream ss;
163163
ss << contexts_.size();
164164
for (const auto& context : contexts_) {
@@ -308,42 +308,40 @@ void Context::Impl::onAcceptOfLane_(
308308

309309
// Keep it alive until we figure out what to do with it.
310310
connectionsWaitingForHello_.insert(connection);
311-
auto pbPacketIn = std::make_shared<proto::Packet>();
312-
TP_VLOG(6) << "Channel context " << id_ << " reading proto (client hello)";
311+
auto npHolderIn = std::make_shared<NopHolder<Packet>>();
312+
TP_VLOG(6) << "Channel context " << id_
313+
<< " reading nop object (client hello)";
313314
connection->read(
314-
*pbPacketIn,
315-
lazyCallbackWrapper_([pbPacketIn,
315+
*npHolderIn,
316+
lazyCallbackWrapper_([npHolderIn,
316317
weakConnection{std::weak_ptr<transport::Connection>(
317318
connection)}](Impl& impl) mutable {
318319
TP_VLOG(6) << "Channel context " << impl.id_
319-
<< " done reading proto (client hello)";
320+
<< " done reading nop object (client hello)";
320321
std::shared_ptr<transport::Connection> connection =
321322
weakConnection.lock();
322323
TP_DCHECK(connection);
323324
impl.connectionsWaitingForHello_.erase(connection);
324-
impl.onReadClientHelloOnLane_(std::move(connection), *pbPacketIn);
325+
impl.onReadClientHelloOnLane_(
326+
std::move(connection), npHolderIn->getObject());
325327
}));
326328
}
327329

328330
void Context::Impl::onReadClientHelloOnLane_(
329331
std::shared_ptr<transport::Connection> connection,
330-
const proto::Packet& pbPacketIn) {
332+
const Packet& nopPacketIn) {
331333
TP_DCHECK(loop_.inLoop());
332-
333-
if (pbPacketIn.has_client_hello()) {
334-
const proto::ClientHello& pbClientHello = pbPacketIn.client_hello();
335-
uint64_t registrationId = pbClientHello.registration_id();
336-
auto iter = connectionRequestRegistrations_.find(registrationId);
337-
// The connection request may have already been deregistered, for example
338-
// because the channel may have been closed.
339-
if (iter != connectionRequestRegistrations_.end()) {
340-
auto fn = std::move(iter->second);
341-
connectionRequestRegistrations_.erase(iter);
342-
fn(Error::kSuccess, std::move(connection));
343-
}
344-
} else {
345-
TP_LOG_ERROR() << "packet contained unknown content: "
346-
<< pbPacketIn.type_case();
334+
TP_DCHECK_EQ(nopPacketIn.index(), nopPacketIn.index_of<ClientHello>());
335+
336+
const ClientHello& nopClientHello = *nopPacketIn.get<ClientHello>();
337+
uint64_t registrationId = nopClientHello.registrationId;
338+
auto iter = connectionRequestRegistrations_.find(registrationId);
339+
// The connection request may have already been deregistered, for example
340+
// because the channel may have been closed.
341+
if (iter != connectionRequestRegistrations_.end()) {
342+
auto fn = std::move(iter->second);
343+
connectionRequestRegistrations_.erase(iter);
344+
fn(Error::kSuccess, std::move(connection));
347345
}
348346
}
349347

‎tensorpipe/channel/mpt/nop_types.h

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <string>
12+
#include <vector>
13+
14+
#include <nop/serializer.h>
15+
#include <nop/structure.h>
16+
#include <nop/types/variant.h>
17+
18+
namespace tensorpipe {
19+
namespace channel {
20+
namespace mpt {
21+
22+
struct LaneAdvertisement {
23+
std::string address;
24+
uint64_t registrationId;
25+
NOP_STRUCTURE(LaneAdvertisement, address, registrationId);
26+
};
27+
28+
struct ServerHello {
29+
std::vector<LaneAdvertisement> laneAdvertisements;
30+
NOP_STRUCTURE(ServerHello, laneAdvertisements);
31+
};
32+
33+
struct ClientHello {
34+
uint64_t registrationId;
35+
NOP_STRUCTURE(ClientHello, registrationId);
36+
};
37+
38+
using Packet = nop::Variant<ServerHello, ClientHello>;
39+
40+
} // namespace mpt
41+
} // namespace channel
42+
} // namespace tensorpipe

‎tensorpipe/core/listener.cc

+20-19
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
#include <tensorpipe/common/error_macros.h>
1818
#include <tensorpipe/common/optional.h>
1919
#include <tensorpipe/core/error.h>
20+
#include <tensorpipe/core/nop_types.h>
2021
#include <tensorpipe/core/pipe.h>
21-
#include <tensorpipe/proto/core.pb.h>
2222
#include <tensorpipe/transport/connection.h>
2323
#include <tensorpipe/transport/listener.h>
2424

@@ -141,7 +141,7 @@ class Listener::Impl : public Listener::PrivateIface,
141141
void onConnectionHelloRead_(
142142
std::string,
143143
std::shared_ptr<transport::Connection>,
144-
const proto::Packet&);
144+
const Packet&);
145145

146146
template <typename T>
147147
friend class LazyCallbackWrapper;
@@ -386,24 +386,26 @@ void Listener::Impl::onAccept_(
386386
TP_DCHECK(loop_.inLoop());
387387
// Keep it alive until we figure out what to do with it.
388388
connectionsWaitingForHello_.insert(connection);
389-
auto pbPacketIn = std::make_shared<proto::Packet>();
389+
auto nopHolderIn = std::make_shared<NopHolder<Packet>>();
390390
TP_VLOG(3) << "Listener " << id_
391-
<< " is reading proto (spontaneous or requested connection)";
391+
<< " is reading nop object (spontaneous or requested connection)";
392392
connection->read(
393-
*pbPacketIn,
394-
lazyCallbackWrapper_([pbPacketIn,
393+
*nopHolderIn,
394+
lazyCallbackWrapper_([nopHolderIn,
395395
transport{std::move(transport)},
396396
weakConnection{std::weak_ptr<transport::Connection>(
397397
connection)}](Impl& impl) mutable {
398398
TP_VLOG(3)
399399
<< "Listener " << impl.id_
400-
<< " done reading proto (spontaneous or requested connection)";
400+
<< " done reading nop object (spontaneous or requested connection)";
401401
std::shared_ptr<transport::Connection> connection =
402402
weakConnection.lock();
403403
TP_DCHECK(connection);
404404
impl.connectionsWaitingForHello_.erase(connection);
405405
impl.onConnectionHelloRead_(
406-
std::move(transport), std::move(connection), *pbPacketIn);
406+
std::move(transport),
407+
std::move(connection),
408+
nopHolderIn->getObject());
407409
}));
408410
}
409411

@@ -429,16 +431,15 @@ void Listener::Impl::armListener_(std::string transport) {
429431
void Listener::Impl::onConnectionHelloRead_(
430432
std::string transport,
431433
std::shared_ptr<transport::Connection> connection,
432-
const proto::Packet& pbPacketIn) {
434+
const Packet& nopPacketIn) {
433435
TP_DCHECK(loop_.inLoop());
434-
if (pbPacketIn.has_spontaneous_connection()) {
435-
const proto::SpontaneousConnection& pbSpontaneousConnection =
436-
pbPacketIn.spontaneous_connection();
436+
if (nopPacketIn.is<SpontaneousConnection>()) {
437+
const SpontaneousConnection& nopSpontaneousConnection =
438+
*nopPacketIn.get<SpontaneousConnection>();
437439
TP_VLOG(3) << "Listener " << id_ << " got spontaneous connection";
438440
std::string pipeId = id_ + ".p" + std::to_string(pipeCounter_++);
439441
TP_VLOG(1) << "Listener " << id_ << " is opening pipe " << pipeId;
440-
const std::string& remoteContextName =
441-
pbSpontaneousConnection.context_name();
442+
const std::string& remoteContextName = nopSpontaneousConnection.contextName;
442443
if (remoteContextName != "") {
443444
std::string aliasPipeId = id_ + "_from_" + remoteContextName;
444445
TP_VLOG(1) << "Pipe " << pipeId << " aliased as " << aliasPipeId;
@@ -453,10 +454,10 @@ void Listener::Impl::onConnectionHelloRead_(
453454
std::move(transport),
454455
std::move(connection));
455456
acceptCallback_.trigger(Error::kSuccess, std::move(pipe));
456-
} else if (pbPacketIn.has_requested_connection()) {
457-
const proto::RequestedConnection& pbRequestedConnection =
458-
pbPacketIn.requested_connection();
459-
uint64_t registrationId = pbRequestedConnection.registration_id();
457+
} else if (nopPacketIn.is<RequestedConnection>()) {
458+
const RequestedConnection& nopRequestedConnection =
459+
*nopPacketIn.get<RequestedConnection>();
460+
uint64_t registrationId = nopRequestedConnection.registrationId;
460461
TP_VLOG(3) << "Listener " << id_ << " got requested connection (#"
461462
<< registrationId << ")";
462463
auto iter = connectionRequestRegistrations_.find(registrationId);
@@ -469,7 +470,7 @@ void Listener::Impl::onConnectionHelloRead_(
469470
}
470471
} else {
471472
TP_LOG_ERROR() << "packet contained unknown content: "
472-
<< pbPacketIn.type_case();
473+
<< nopPacketIn.index();
473474
}
474475
}
475476

‎tensorpipe/core/nop_types.h

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <string>
12+
#include <unordered_map>
13+
#include <vector>
14+
15+
#include <nop/serializer.h>
16+
#include <nop/structure.h>
17+
#include <nop/types/variant.h>
18+
19+
namespace tensorpipe {
20+
21+
struct SpontaneousConnection {
22+
std::string contextName;
23+
NOP_STRUCTURE(SpontaneousConnection, contextName);
24+
};
25+
26+
struct RequestedConnection {
27+
uint64_t registrationId;
28+
NOP_STRUCTURE(RequestedConnection, registrationId);
29+
};
30+
31+
struct TransportAdvertisement {
32+
std::string domainDescriptor;
33+
NOP_STRUCTURE(TransportAdvertisement, domainDescriptor);
34+
};
35+
36+
struct ChannelAdvertisement {
37+
std::string domainDescriptor;
38+
NOP_STRUCTURE(ChannelAdvertisement, domainDescriptor);
39+
};
40+
41+
struct Brochure {
42+
std::unordered_map<std::string, TransportAdvertisement>
43+
transportAdvertisement;
44+
std::unordered_map<std::string, ChannelAdvertisement> channelAdvertisement;
45+
NOP_STRUCTURE(Brochure, transportAdvertisement, channelAdvertisement);
46+
};
47+
48+
struct ChannelSelection {
49+
uint64_t registrationId;
50+
NOP_STRUCTURE(ChannelSelection, registrationId);
51+
};
52+
53+
struct BrochureAnswer {
54+
std::string transport;
55+
std::string address;
56+
uint64_t registrationId;
57+
std::unordered_map<std::string, ChannelSelection> channelSelection;
58+
NOP_STRUCTURE(
59+
BrochureAnswer,
60+
transport,
61+
address,
62+
registrationId,
63+
channelSelection);
64+
};
65+
66+
enum class DeviceType { DEVICE_TYPE_UNSPECIFIED, DEVICE_TYPE_CPU };
67+
68+
struct MessageDescriptor {
69+
struct PayloadDescriptor {
70+
int64_t sizeInBytes;
71+
std::string metadata;
72+
NOP_STRUCTURE(PayloadDescriptor, sizeInBytes, metadata);
73+
};
74+
75+
struct TensorDescriptor {
76+
int64_t sizeInBytes;
77+
std::string metadata;
78+
79+
DeviceType deviceType;
80+
std::string channelName;
81+
std::string channelDescriptor;
82+
NOP_STRUCTURE(
83+
TensorDescriptor,
84+
sizeInBytes,
85+
metadata,
86+
deviceType,
87+
channelName,
88+
channelDescriptor);
89+
};
90+
91+
std::string metadata;
92+
std::vector<PayloadDescriptor> payloadDescriptors;
93+
std::vector<TensorDescriptor> tensorDescriptors;
94+
NOP_STRUCTURE(
95+
MessageDescriptor,
96+
metadata,
97+
payloadDescriptors,
98+
tensorDescriptors);
99+
};
100+
101+
using Packet = nop::Variant<
102+
SpontaneousConnection,
103+
RequestedConnection,
104+
Brochure,
105+
BrochureAnswer,
106+
MessageDescriptor>;
107+
108+
} // namespace tensorpipe

‎tensorpipe/core/pipe.cc

+154-144
Large diffs are not rendered by default.

‎tensorpipe/proto/channel/mpt.proto

-31
This file was deleted.

‎tensorpipe/proto/core.proto

-47
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,6 @@ syntax = "proto3";
1010

1111
package tensorpipe.proto;
1212

13-
message SpontaneousConnection {
14-
string context_name = 1;
15-
}
16-
17-
message RequestedConnection {
18-
uint64 registration_id = 1;
19-
}
20-
2113
message TransportAdvertisement {
2214
string domain_descriptor = 1;
2315
}
@@ -31,49 +23,10 @@ message Brochure {
3123
map<string, ChannelAdvertisement> channel_advertisement = 2;
3224
}
3325

34-
message ChannelSelection {
35-
uint64 registration_id = 1;
36-
}
37-
38-
message BrochureAnswer {
39-
string transport = 1;
40-
string address = 2;
41-
uint64 registration_id = 3;
42-
43-
map<string, ChannelSelection> channel_selection = 4;
44-
}
45-
46-
enum DeviceType {
47-
DEVICE_TYPE_UNSPECIFIED = 0;
48-
DEVICE_TYPE_CPU = 1;
49-
}
50-
5126
message MessageDescriptor {
5227
message PayloadDescriptor {
5328
int64 size_in_bytes = 1;
54-
bytes metadata = 2;
5529
}
5630

57-
message TensorDescriptor {
58-
int64 size_in_bytes = 1;
59-
bytes metadata = 2;
60-
61-
DeviceType device_type = 3;
62-
string channel_name = 4;
63-
bytes channel_descriptor = 5;
64-
}
65-
66-
bytes metadata = 1;
6731
repeated PayloadDescriptor payload_descriptors = 2;
68-
repeated TensorDescriptor tensor_descriptors = 3;
69-
}
70-
71-
message Packet {
72-
oneof type {
73-
SpontaneousConnection spontaneous_connection = 1;
74-
RequestedConnection requested_connection = 2;
75-
Brochure brochure = 3;
76-
BrochureAnswer brochure_answer = 4;
77-
MessageDescriptor message_descriptor = 5;
78-
}
7932
}

‎tensorpipe/test/transport/connection_test.cc

+20-10
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
#include <array>
1212

13-
#include <tensorpipe/proto/core.pb.h>
13+
#include <nop/serializer.h>
14+
#include <nop/structure.h>
1415

1516
using namespace tensorpipe;
1617
using namespace tensorpipe::transport;
@@ -89,25 +90,34 @@ TEST_P(TransportTest, DISABLED_Connection_DestroyConnectionFromCallback) {
8990
});
9091
}
9192

92-
TEST_P(TransportTest, Connection_ProtobufWrite) {
93+
namespace {
94+
95+
struct MyNopType {
96+
uint32_t myIntField;
97+
NOP_STRUCTURE(MyNopType, myIntField);
98+
};
99+
100+
} // namespace
101+
102+
TEST_P(TransportTest, Connection_NopWrite) {
93103
constexpr size_t kSize = 0x42;
94104

95105
testConnection(
96106
[&](std::shared_ptr<Connection> conn) {
97-
auto message = std::make_shared<
98-
tensorpipe::proto::MessageDescriptor::PayloadDescriptor>();
99-
conn->read(*message, [&, conn, message](const Error& error) {
107+
auto holder = std::make_shared<NopHolder<MyNopType>>();
108+
MyNopType& object = holder->getObject();
109+
conn->read(*holder, [&, conn, holder](const Error& error) {
100110
ASSERT_FALSE(error) << error.what();
101-
ASSERT_EQ(message->size_in_bytes(), kSize);
111+
ASSERT_EQ(object.myIntField, kSize);
102112
peers_->done(PeerGroup::kServer);
103113
});
104114
peers_->join(PeerGroup::kServer);
105115
},
106116
[&](std::shared_ptr<Connection> conn) {
107-
auto message = std::make_shared<
108-
tensorpipe::proto::MessageDescriptor::PayloadDescriptor>();
109-
message->set_size_in_bytes(kSize);
110-
conn->write(*message, [&, conn, message](const Error& error) {
117+
auto holder = std::make_shared<NopHolder<MyNopType>>();
118+
MyNopType& object = holder->getObject();
119+
object.myIntField = kSize;
120+
conn->write(*holder, [&, conn, holder](const Error& error) {
111121
ASSERT_FALSE(error) << error.what();
112122
peers_->done(PeerGroup::kClient);
113123
});

‎tensorpipe/test/transport/shm/connection_test.cc

+18-10
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <tensorpipe/proto/core.pb.h>
109
#include <tensorpipe/test/transport/shm/shm_test.h>
1110

1211
#include <gtest/gtest.h>
12+
#include <nop/serializer.h>
13+
#include <nop/structure.h>
1314

1415
using namespace tensorpipe;
1516
using namespace tensorpipe::transport;
@@ -132,18 +133,26 @@ TEST_P(ShmTransportTest, QueueWrites) {
132133
});
133134
}
134135

135-
TEST_P(ShmTransportTest, ProtobufWriteWrapAround) {
136+
namespace {
137+
138+
struct MyNopType {
139+
std::string myStringField;
140+
NOP_STRUCTURE(MyNopType, myStringField);
141+
};
142+
143+
} // namespace
144+
145+
TEST_P(ShmTransportTest, NopWriteWrapAround) {
136146
constexpr int numMsg = 2;
137147
constexpr size_t kSize = (3 * kBufferSize) / 4;
138148

139149
testConnection(
140150
[&](std::shared_ptr<Connection> conn) {
141151
for (int i = 0; i < numMsg; ++i) {
142-
auto message =
143-
std::make_shared<tensorpipe::proto::ChannelAdvertisement>();
144-
conn->read(*message, [&, conn, message, i](const Error& error) {
152+
auto holder = std::make_shared<NopHolder<MyNopType>>();
153+
conn->read(*holder, [&, conn, holder, i](const Error& error) {
145154
ASSERT_FALSE(error) << error.what();
146-
ASSERT_EQ(message->domain_descriptor().length(), kSize);
155+
ASSERT_EQ(holder->getObject().myStringField.length(), kSize);
147156
if (i == numMsg - 1) {
148157
peers_->done(PeerGroup::kServer);
149158
}
@@ -153,10 +162,9 @@ TEST_P(ShmTransportTest, ProtobufWriteWrapAround) {
153162
},
154163
[&](std::shared_ptr<Connection> conn) {
155164
for (int i = 0; i < numMsg; ++i) {
156-
auto message =
157-
std::make_shared<tensorpipe::proto::ChannelAdvertisement>();
158-
message->set_domain_descriptor(std::string(kSize, 'B'));
159-
conn->write(*message, [&, conn, message, i](const Error& error) {
165+
auto holder = std::make_shared<NopHolder<MyNopType>>();
166+
holder->getObject().myStringField = std::string(kSize, 'B');
167+
conn->write(*holder, [&, conn, holder, i](const Error& error) {
160168
ASSERT_FALSE(error) << error.what();
161169
if (i == numMsg - 1) {
162170
peers_->done(PeerGroup::kClient);

‎tensorpipe/transport/connection.cc

+13-14
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,26 @@
88

99
#include <tensorpipe/transport/connection.h>
1010

11-
#include <google/protobuf/message_lite.h>
12-
1311
#include <tensorpipe/common/defs.h>
1412

1513
namespace tensorpipe {
1614
namespace transport {
1715

18-
void Connection::read(
19-
google::protobuf::MessageLite& message,
20-
read_proto_callback_fn fn) {
21-
read([&message, fn{std::move(fn)}](
16+
void Connection::read(AbstractNopHolder& object, read_nop_callback_fn fn) {
17+
read([&object, fn{std::move(fn)}](
2218
const Error& error, const void* ptr, size_t len) {
2319
if (!error) {
24-
message.ParseFromArray(ptr, len);
20+
nop::BufferReader reader(reinterpret_cast<const uint8_t*>(ptr), len);
21+
nop::Status<void> status = object.read(reader);
22+
TP_THROW_ASSERT_IF(status.has_error())
23+
<< "Error reading nop object: " << status.GetErrorMessage();
2524
}
2625
fn(error);
2726
});
2827
}
2928

30-
void Connection::write(
31-
const google::protobuf::MessageLite& message,
32-
write_callback_fn fn) {
33-
// FIXME use ByteSizeLong (introduced in a newer protobuf).
34-
const auto len = message.ByteSize();
29+
void Connection::write(const AbstractNopHolder& object, write_callback_fn fn) {
30+
const size_t len = object.getSize();
3531

3632
// Using a shared_ptr instead of unique_ptr because if the lambda captures a
3733
// unique_ptr then it becomes non-copyable, which prevents it from being
@@ -43,8 +39,11 @@ void Connection::write(
4339
auto buf = std::shared_ptr<uint8_t>(
4440
new uint8_t[len], std::default_delete<uint8_t[]>());
4541
auto ptr = buf.get();
46-
auto end = message.SerializeWithCachedSizesToArray(ptr);
47-
TP_DCHECK_EQ(end, ptr + len) << "Failed to serialize protobuf message.";
42+
43+
nop::BufferWriter writer(ptr, len);
44+
nop::Status<void> status = object.write(writer);
45+
TP_THROW_ASSERT_IF(status.has_error())
46+
<< "Error writing nop object: " << status.GetErrorMessage();
4847

4948
// Perform write and forward callback.
5049
write(

‎tensorpipe/transport/connection.h

+7-18
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,7 @@
1111
#include <functional>
1212

1313
#include <tensorpipe/common/error.h>
14-
15-
namespace google {
16-
namespace protobuf {
17-
18-
class MessageLite;
19-
20-
} // namespace protobuf
21-
} // namespace google
14+
#include <tensorpipe/common/nop.h>
2215

2316
namespace tensorpipe {
2417
namespace transport {
@@ -37,34 +30,30 @@ class Connection {
3730
virtual void write(const void* ptr, size_t length, write_callback_fn fn) = 0;
3831

3932
//
40-
// Helper functions for reading/writing protobuf messages.
33+
// Helper functions for reading/writing nop objects.
4134
//
4235

43-
// Read and parse protobuf message.
36+
// Read and parse a nop object.
4437
//
4538
// This function may be overridden by a subclass.
4639
//
4740
// For example, the shm transport may be able to bypass reading into a
4841
// temporary buffer and instead instead read directly from its peer's
4942
// ring buffer. This saves an allocation and a memory copy.
5043
//
51-
using read_proto_callback_fn = std::function<void(const Error& error)>;
44+
using read_nop_callback_fn = std::function<void(const Error& error)>;
5245

53-
virtual void read(
54-
google::protobuf::MessageLite& message,
55-
read_proto_callback_fn fn);
46+
virtual void read(AbstractNopHolder& object, read_nop_callback_fn fn);
5647

57-
// Serialize and write protobuf message.
48+
// Serialize and write nop object.
5849
//
5950
// This function may be overridden by a subclass.
6051
//
6152
// For example, the shm transport may be able to bypass serialization
6253
// into a temporary buffer and instead instead serialize directly into
6354
// its peer's ring buffer. This saves an allocation and a memory copy.
6455
//
65-
virtual void write(
66-
const google::protobuf::MessageLite& message,
67-
write_callback_fn fn);
56+
virtual void write(const AbstractNopHolder& object, write_callback_fn fn);
6857

6958
// Tell the connection what its identifier is.
7059
//

‎tensorpipe/transport/shm/connection.h

+15-4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
#include <tensorpipe/transport/defs.h>
1515
#include <tensorpipe/transport/shm/context.h>
1616

17+
namespace google {
18+
namespace protobuf {
19+
20+
class MessageLite;
21+
22+
} // namespace protobuf
23+
} // namespace google
24+
1725
namespace tensorpipe {
1826
namespace transport {
1927
namespace shm {
@@ -44,14 +52,17 @@ class Connection final : public transport::Connection {
4452

4553
// Queue a read operation.
4654
void read(read_callback_fn fn) override;
47-
void read(google::protobuf::MessageLite& message, read_proto_callback_fn fn)
48-
override;
55+
using transport::Connection::read;
56+
using read_proto_callback_fn = std::function<void(const Error& error)>;
57+
void read(google::protobuf::MessageLite& message, read_proto_callback_fn fn);
4958
void read(void* ptr, size_t length, read_callback_fn fn) override;
5059

5160
// Perform a write operation.
5261
void write(const void* ptr, size_t length, write_callback_fn fn) override;
53-
void write(const google::protobuf::MessageLite& message, write_callback_fn fn)
54-
override;
62+
using transport::Connection::write;
63+
void write(
64+
const google::protobuf::MessageLite& message,
65+
write_callback_fn fn);
5566

5667
// Tell the connection what its identifier is.
5768
void setId(std::string id) override;

0 commit comments

Comments
 (0)
This repository has been archived.