13
13
14
14
#include < tensorpipe/channel/error.h>
15
15
#include < tensorpipe/channel/helpers.h>
16
+ #include < tensorpipe/channel/mpt/nop_types.h>
16
17
#include < tensorpipe/common/callback.h>
17
18
#include < tensorpipe/common/defs.h>
18
19
#include < tensorpipe/common/error_macros.h>
19
- #include < tensorpipe/proto/channel/mpt.pb.h>
20
20
#include < tensorpipe/transport/context.h>
21
21
#include < tensorpipe/transport/error.h>
22
22
#include < tensorpipe/transport/listener.h>
@@ -103,7 +103,7 @@ class Channel::Impl : public std::enable_shared_from_this<Channel::Impl> {
103
103
void closeFromLoop_ ();
104
104
105
105
// Called when client reads the server's hello on backbone connection
106
- void onClientReadHelloOnConnection_ (const proto:: Packet& pbPacketIn );
106
+ void onClientReadHelloOnConnection_ (const Packet& nopPacketIn );
107
107
108
108
// Called when server accepts new client connection for lane
109
109
void onServerAcceptOfLane_ (
@@ -209,23 +209,27 @@ void Channel::Impl::initFromLoop_() {
209
209
TP_DCHECK_EQ (state_, UNINITIALIZED);
210
210
if (endpoint_ == Endpoint::kConnect ) {
211
211
state_ = CLIENT_READING_HELLO;
212
- auto packet = std::make_shared<proto::Packet>();
213
- TP_VLOG (6 ) << " Channel " << id_ << " reading proto (server hello)" ;
214
- connection_->read (*packet, lazyCallbackWrapper_ ([packet](Impl& impl) {
215
- TP_VLOG (6 ) << " Channel " << impl.id_
216
- << " done reading proto (server hello)" ;
217
- impl.onClientReadHelloOnConnection_ (*packet);
218
- }));
212
+ auto nopHolderIn = std::make_shared<NopHolder<Packet>>();
213
+ TP_VLOG (6 ) << " Channel " << id_ << " reading nop object (server hello)" ;
214
+ connection_->read (
215
+ *nopHolderIn, lazyCallbackWrapper_ ([nopHolderIn](Impl& impl) {
216
+ TP_VLOG (6 ) << " Channel " << impl.id_
217
+ << " done reading nop object (server hello)" ;
218
+ impl.onClientReadHelloOnConnection_ (nopHolderIn->getObject ());
219
+ }));
219
220
} else if (endpoint_ == Endpoint::kListen ) {
220
221
state_ = SERVER_ACCEPTING_LANES;
221
222
const std::vector<std::string>& addresses = context_->addresses ();
222
223
TP_DCHECK_EQ (addresses.size (), numLanes_);
223
- auto packet = std::make_shared<proto::Packet>();
224
- proto::ServerHello* pbServerHello = packet->mutable_server_hello ();
224
+ auto nopHolderOut = std::make_shared<NopHolder<Packet>>();
225
+ Packet& nopPacket = nopHolderOut->getObject ();
226
+ nopPacket.Become (nopPacket.index_of <ServerHello>());
227
+ ServerHello& nopServerHello = *nopPacket.get <ServerHello>();
225
228
for (uint64_t laneIdx = 0 ; laneIdx < numLanes_; ++laneIdx) {
226
- proto::LaneAdvertisement* pbLaneAdvertisement =
227
- pbServerHello->add_lane_advertisements ();
228
- pbLaneAdvertisement->set_address (addresses[laneIdx]);
229
+ nopServerHello.laneAdvertisements .emplace_back ();
230
+ LaneAdvertisement& nopLaneAdvertisement =
231
+ nopServerHello.laneAdvertisements .back ();
232
+ nopLaneAdvertisement.address = addresses[laneIdx];
229
233
TP_VLOG (6 ) << " Channel " << id_ << " requesting connection (for lane "
230
234
<< laneIdx << " )" ;
231
235
uint64_t token = context_->registerConnectionRequest (
@@ -240,14 +244,15 @@ void Channel::Impl::initFromLoop_() {
240
244
impl.onServerAcceptOfLane_ (laneIdx, std::move (connection));
241
245
}));
242
246
laneRegistrationIds_.emplace (laneIdx, token);
243
- pbLaneAdvertisement-> set_registration_id ( token) ;
247
+ nopLaneAdvertisement. registrationId = token;
244
248
numLanesBeingAccepted_++;
245
249
}
246
- TP_VLOG (6 ) << " Channel " << id_ << " writing proto (server hello)" ;
247
- connection_->write (*packet, lazyCallbackWrapper_ ([packet](Impl& impl) {
248
- TP_VLOG (6 ) << " Channel " << impl.id_
249
- << " done writing proto (server hello)" ;
250
- }));
250
+ TP_VLOG (6 ) << " Channel " << id_ << " writing nop object (server hello)" ;
251
+ connection_->write (
252
+ *nopHolderOut, lazyCallbackWrapper_ ([nopHolderOut](Impl& impl) {
253
+ TP_VLOG (6 ) << " Channel " << impl.id_
254
+ << " done writing nop object (server hello)" ;
255
+ }));
251
256
} else {
252
257
TP_THROW_ASSERT () << " unknown endpoint" ;
253
258
}
@@ -406,29 +411,31 @@ void Channel::Impl::setIdFromLoop_(std::string id) {
406
411
id_ = std::move (id);
407
412
}
408
413
409
- void Channel::Impl::onClientReadHelloOnConnection_ (
410
- const proto::Packet& pbPacketIn) {
414
+ void Channel::Impl::onClientReadHelloOnConnection_ (const Packet& nopPacketIn) {
411
415
TP_DCHECK (loop_.inLoop ());
412
416
TP_DCHECK_EQ (state_, CLIENT_READING_HELLO);
413
- TP_DCHECK_EQ (pbPacketIn. type_case (), proto::Packet:: kServerHello );
417
+ TP_DCHECK_EQ (nopPacketIn. index (), nopPacketIn. index_of <ServerHello>() );
414
418
415
- const proto:: ServerHello& pbServerHello = pbPacketIn. server_hello ();
416
- TP_DCHECK_EQ (pbServerHello. lane_advertisements () .size (), numLanes_);
419
+ const ServerHello& nopServerHello = *nopPacketIn. get <ServerHello> ();
420
+ TP_DCHECK_EQ (nopServerHello. laneAdvertisements .size (), numLanes_);
417
421
lanes_.resize (numLanes_);
418
422
for (uint64_t laneIdx = 0 ; laneIdx < numLanes_; ++laneIdx) {
419
- const proto:: LaneAdvertisement& pbLaneAdvertisement =
420
- pbServerHello. lane_advertisements (). Get ( laneIdx) ;
423
+ const LaneAdvertisement& nopLaneAdvertisement =
424
+ nopServerHello. laneAdvertisements [ laneIdx] ;
421
425
std::shared_ptr<transport::Connection> lane =
422
- context_->connect (laneIdx, pbLaneAdvertisement.address ());
423
- auto pbPacketOut = std::make_shared<proto::Packet>();
424
- proto::ClientHello* pbClientHello = pbPacketOut->mutable_client_hello ();
425
- pbClientHello->set_registration_id (pbLaneAdvertisement.registration_id ());
426
- TP_VLOG (6 ) << " Channel " << id_ << " writing proto (client hello) on lane "
427
- << laneIdx;
426
+ context_->connect (laneIdx, nopLaneAdvertisement.address );
427
+ auto nopHolderOut = std::make_shared<NopHolder<Packet>>();
428
+ Packet& nopPacket = nopHolderOut->getObject ();
429
+ nopPacket.Become (nopPacket.index_of <ClientHello>());
430
+ ClientHello& nopClientHello = *nopPacket.get <ClientHello>();
431
+ nopClientHello.registrationId = nopLaneAdvertisement.registrationId ;
432
+ TP_VLOG (6 ) << " Channel " << id_
433
+ << " writing nop object (client hello) on lane " << laneIdx;
428
434
lane->write (
429
- *pbPacketOut, lazyCallbackWrapper_ ([laneIdx, pbPacketOut](Impl& impl) {
435
+ *nopHolderOut,
436
+ lazyCallbackWrapper_ ([laneIdx, nopHolderOut](Impl& impl) {
430
437
TP_VLOG (6 ) << " Channel " << impl.id_
431
- << " done writing proto (client hello) on lane "
438
+ << " done writing nop object (client hello) on lane "
432
439
<< laneIdx;
433
440
}));
434
441
lanes_[laneIdx] = std::move (lane);
0 commit comments