Skip to content

Commit

Permalink
Update channel passing to use reference ids instead of stream ids
Browse files Browse the repository at this point in the history
Signed-off-by: Derek McGowan <[email protected]> (github: dmcgowan)
  • Loading branch information
dmcgowan committed Jul 23, 2014
1 parent 934c215 commit de99fa2
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 47 deletions.
26 changes: 9 additions & 17 deletions spdy/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,17 @@ func (s *Transport) encodeChannel(v reflect.Value) ([]byte, error) {
}

// Get stream identifier?
streamID := rc.stream.Identifier()
var buf [9]byte
referenceID := rc.referenceID
buf := make([]byte, 9)
if rc.direction == inbound {
buf[0] = 0x02 // Reverse direction
} else if rc.direction == outbound {
buf[0] = 0x01 // Reverse direction
} else {
return nil, errors.New("invalid direction")
}
written := binary.PutUvarint(buf[1:], uint64(streamID))
if written > 4 {
return nil, errors.New("wrote unexpected stream id size")
}
return buf[:(written + 1)], nil
binary.BigEndian.PutUint64(buf[1:], referenceID)
return buf, nil
}

func (s *Transport) decodeChannel(v reflect.Value, b []byte) error {
Expand All @@ -49,17 +46,12 @@ func (s *Transport) decodeChannel(v reflect.Value, b []byte) error {
return errors.New("unexpected direction")
}

streamID, readN := binary.Uvarint(b[1:])
if readN > 4 {
return errors.New("read unexpected stream id size")
}
stream := s.conn.FindStream(uint32(streamID))
if stream == nil {
return errors.New("stream does not exist")
referenceID := binary.BigEndian.Uint64(b[1:])
c := s.getChannel(referenceID)
if c == nil {
return errors.New("channel does not exist")
}
rc.session = s
rc.stream = stream
v.Set(reflect.ValueOf(rc))
v.Set(reflect.ValueOf(*c))

return nil
}
Expand Down
125 changes: 95 additions & 30 deletions spdy/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ type Transport struct {
conn *spdystream.Connection
handler codec.Handle

streamChan chan *spdystream.Stream
receiverChan chan *channel
channelC *sync.Cond
channels map[uint64]*channel

referenceLock sync.Mutex
referenceCounter uint64
Expand All @@ -45,9 +47,11 @@ type Transport struct {
}

type channel struct {
stream *spdystream.Stream
session *Transport
direction direction
referenceID uint64
parentID uint64
stream *spdystream.Stream
session *Transport
direction direction
}

// NewClientTransport creates a new stream transport from the
Expand All @@ -72,7 +76,9 @@ func newSession(conn net.Conn, server bool) (*Transport, error) {
referenceCounter = 1
}
session := &Transport{
streamChan: make(chan *spdystream.Stream),
receiverChan: make(chan *channel),
channelC: sync.NewCond(new(sync.Mutex)),
channels: make(map[uint64]*channel),
referenceCounter: referenceCounter,
byteStreamC: sync.NewCond(new(sync.Mutex)),
byteStreams: make(map[uint64]*byteStream),
Expand All @@ -95,15 +101,16 @@ func newSession(conn net.Conn, server bool) (*Transport, error) {

func (s *Transport) newStreamHandler(stream *spdystream.Stream) {
referenceIDString := stream.Headers().Get("libchan-ref")
parentIDString := stream.Headers().Get("libchan-parent-ref")

returnHeaders := http.Header{}
finish := false
if referenceIDString != "" {
referenceID, parseErr := strconv.ParseUint(referenceIDString, 10, 64)
if parseErr != nil {
returnHeaders.Set("status", "400")
finish = true
} else {
referenceID, parseErr := strconv.ParseUint(referenceIDString, 10, 64)
if parseErr != nil {
returnHeaders.Set("status", "400")
finish = true
} else {
if parentIDString == "" {
byteStream := &byteStream{
ReferenceID: referenceID,
Stream: stream,
Expand All @@ -112,10 +119,29 @@ func (s *Transport) newStreamHandler(stream *spdystream.Stream) {
s.byteStreams[referenceID] = byteStream
s.byteStreamC.Broadcast()
s.byteStreamC.L.Unlock()
}
} else {
if stream.Parent() == nil {
s.streamChan <- stream

} else {
parentID, parseErr := strconv.ParseUint(parentIDString, 10, 64)
if parseErr != nil {
returnHeaders.Set("status", "400")
finish = true
} else {
c := &channel{
referenceID: referenceID,
parentID: parentID,
stream: stream,
session: s,
}
s.channelC.L.Lock()
s.channels[referenceID] = c
s.channelC.Broadcast()
s.channelC.L.Unlock()

if parentID == 0 {
c.direction = inbound
s.receiverChan <- c
}
}
}
}

Expand All @@ -133,6 +159,17 @@ func (s *Transport) getByteStream(referenceID uint64) *byteStream {
return bs
}

func (s *Transport) getChannel(referenceID uint64) *channel {
s.channelC.L.Lock()
c, ok := s.channels[referenceID]
if !ok {
s.channelC.Wait()
c, ok = s.channels[referenceID]
}
s.channelC.L.Unlock()
return c
}

func (s *Transport) dial(referenceID uint64) (*byteStream, error) {
headers := http.Header{}
headers.Set("libchan-ref", strconv.FormatUint(referenceID, 10))
Expand All @@ -147,11 +184,16 @@ func (s *Transport) dial(referenceID uint64) (*byteStream, error) {
return bs, nil
}

func (s *Transport) createByteStream() (io.ReadWriteCloser, error) {
func (s *Transport) nextReferenceID() uint64 {
s.referenceLock.Lock()
referenceID := s.referenceCounter
s.referenceCounter = referenceID + 2
s.referenceLock.Unlock()
return referenceID
}

func (s *Transport) createByteStream() (io.ReadWriteCloser, error) {
referenceID := s.nextReferenceID()

byteStream, bsErr := s.dial(referenceID)
if bsErr != nil {
Expand Down Expand Up @@ -233,43 +275,66 @@ func (s *Transport) Close() error {
// end will get picked up on the remote end through the remote calling
// WaitReceiveChannel.
func (s *Transport) NewSendChannel() (libchan.Sender, error) {
stream, streamErr := s.conn.CreateStream(http.Header{}, nil, false)
referenceID := s.nextReferenceID()
headers := http.Header{}
headers.Set("libchan-ref", strconv.FormatUint(referenceID, 10))
headers.Set("libchan-parent-ref", "0")

stream, streamErr := s.conn.CreateStream(headers, nil, false)
if streamErr != nil {
return nil, streamErr
}
return &channel{stream: stream, session: s, direction: outbound}, nil
c := &channel{
referenceID: referenceID,
stream: stream,
session: s,
direction: outbound,
}

s.channelC.L.Lock()
s.channels[referenceID] = c
s.channelC.L.Unlock()

return c, nil
}

// WaitReceiveChannel waits for a new channel be created by a remote
// call to NewSendChannel.
func (s *Transport) WaitReceiveChannel() (libchan.Receiver, error) {
stream, ok := <-s.streamChan
r, ok := <-s.receiverChan
if !ok {
return nil, io.EOF
}

return &channel{
stream: stream,
session: s,
direction: inbound,
}, nil
return r, nil
}

func (c *channel) createSubChannel(direction direction) (libchan.Sender, libchan.Receiver, error) {
if c.direction == inbound {
return nil, nil, errors.New("cannot create sub channel of an inbound channel")
}
referenceID := c.session.nextReferenceID()
headers := http.Header{}
headers.Set("libchan-ref", strconv.FormatUint(referenceID, 10))
headers.Set("libchan-parent-ref", strconv.FormatUint(c.referenceID, 10))

stream, streamErr := c.stream.CreateSubStream(http.Header{}, false)
stream, streamErr := c.stream.CreateSubStream(headers, false)
if streamErr != nil {
return nil, nil, streamErr
}
channel := &channel{
stream: stream,
session: c.session,
direction: direction,
subChannel := &channel{
referenceID: referenceID,
parentID: c.referenceID,
stream: stream,
session: c.session,
direction: direction,
}
return channel, channel, nil

c.session.channelC.L.Lock()
c.session.channels[referenceID] = subChannel
c.session.channelC.L.Unlock()

return subChannel, subChannel, nil
}

// CreateByteStream creates a new byte stream using an underlying
Expand Down

0 comments on commit de99fa2

Please sign in to comment.