summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--btpd/net.c154
-rw-r--r--btpd/peer.h7
2 files changed, 80 insertions, 81 deletions
diff --git a/btpd/net.c b/btpd/net.c
index 90f1203..55731a7 100644
--- a/btpd/net.c
+++ b/btpd/net.c
@@ -112,10 +112,6 @@ net_write(struct peer *p, unsigned long wmax)
     }
     if (!BTPDQ_EMPTY(&p->outq))
 	event_add(&p->out_ev, WRITE_TIMEOUT);
-    else if (p->flags & PF_WRITE_CLOSE) {
-	btpd_log(BTPD_L_CONN, "Closed because of write flag.\n");
-	peer_kill(p);
-    }
 
     return nwritten;
 }
@@ -128,12 +124,12 @@ net_set_state(struct peer *p, int state, size_t size)
 }
 
 static int
-net_dispatch_msg(struct peer *p, uint32_t mlen, uint8_t mnum, uint8_t *buf)
+net_dispatch_msg(struct peer *p, uint8_t *buf)
 {
     uint32_t index, begin, length;
     int res = 0;
 
-    switch (mnum) {
+    switch (p->msg_num) {
     case MSG_CHOKE:
 	peer_on_choke(p);
 	break;
@@ -147,16 +143,16 @@ net_dispatch_msg(struct peer *p, uint32_t mlen, uint8_t mnum, uint8_t *buf)
 	peer_on_uninterest(p);
 	break;
     case MSG_HAVE:
-	peer_on_have(p, net_read32(buf + 5));
+	peer_on_have(p, net_read32(buf));
 	break;
     case MSG_BITFIELD:
-	peer_on_bitfield(p, buf + 5);
+	peer_on_bitfield(p, buf);
 	break;
     case MSG_REQUEST:
 	if ((p->flags & (PF_P_WANT|PF_I_CHOKE)) == PF_P_WANT) {
-	    index = net_read32(buf + 5);
-	    begin = net_read32(buf + 9);
-	    length = net_read32(buf + 13);
+	    index = net_read32(buf);
+	    begin = net_read32(buf + 4);
+	    length = net_read32(buf + 8);
 	    if ((length > PIECE_BLOCKLEN
 		    || index >= p->tp->meta.npieces
 		    || !has_bit(p->tp->piece_field, index)
@@ -168,16 +164,16 @@ net_dispatch_msg(struct peer *p, uint32_t mlen, uint8_t mnum, uint8_t *buf)
 	}
 	break;
     case MSG_CANCEL:
-	index = net_read32(buf + 5);
-	begin = net_read32(buf + 9);
-	length = net_read32(buf + 13);
+	index = net_read32(buf);
+	begin = net_read32(buf + 4);
+	length = net_read32(buf + 8);
 	peer_on_cancel(p, index, begin, length);
 	break;
     case MSG_PIECE:
-	index = net_read32(buf + 5);
-	begin = net_read32(buf + 9);
-	length = mlen - 9;
-	peer_on_piece(p, index, begin, length, buf + 13);
+	index = net_read32(buf);
+	begin = net_read32(buf + 4);
+	length = p->msg_len - 9;
+	peer_on_piece(p, index, begin, length, buf + 8);
 	break;
     default:
 	abort();
@@ -186,9 +182,10 @@ net_dispatch_msg(struct peer *p, uint32_t mlen, uint8_t mnum, uint8_t *buf)
 }
 
 static int
-net_mh_ok(struct peer *p, uint32_t mlen, uint8_t mnum)
+net_mh_ok(struct peer *p)
 {
-    switch (mnum) {
+    uint32_t mlen = p->msg_len;
+    switch (p->msg_num) {
     case MSG_CHOKE:
     case MSG_UNCHOKE:
     case MSG_INTEREST:
@@ -218,20 +215,15 @@ net_progress(struct peer *p, size_t length)
 }
 
 static int
-net_state_foo(struct peer *p, struct io_buffer *iob)
+net_state(struct peer *p, struct io_buffer *iob)
 {
-    uint32_t mlen;
-    uint32_t mnum;
-
     switch (p->net_state) {
     case SHAKE_PSTR:
-	assert(iob->buf_len >= 28);
 	if (bcmp(iob->buf, "\x13""BitTorrent protocol", 20) != 0)
 	    goto bad;
 	net_set_state(p, SHAKE_INFO, 20);
 	return 28;
     case SHAKE_INFO:
-	assert(iob->buf_len >= 20);
 	if (p->flags & PF_INCOMING) {
 	    struct torrent *tp = torrent_get_by_hash(iob->buf);
 	    if (tp == NULL)
@@ -243,7 +235,6 @@ net_state_foo(struct peer *p, struct io_buffer *iob)
 	net_set_state(p, SHAKE_ID, 20);
 	return 20;
     case SHAKE_ID:
-	assert(iob->buf_len >= 20);
 	if ((torrent_has_peer(p->tp, iob->buf)
 		|| bcmp(iob->buf, btpd.peer_id, 20) == 0))
 	    goto bad;
@@ -262,43 +253,35 @@ net_state_foo(struct peer *p, struct io_buffer *iob)
 	net_set_state(p, NET_MSGSIZE, 4);
 	return 20;
     case NET_MSGSIZE:
-	assert(iob->buf_len >= 4);
-	if (bcmp(iob->buf, "\0\0\0\0", 4) == 0)
-	    return 4;
-	else {
-	    net_set_state(p, NET_MSGHEAD, 5);
-	    return 0;
-	}
+	p->msg_len = net_read32(iob->buf);
+	if (p->msg_len != 0)
+	    net_set_state(p, NET_MSGHEAD, 1);
+	return 4;
     case NET_MSGHEAD:
-	assert(iob->buf_len >= 5);
-	mlen = net_read32(iob->buf);
-	mnum = iob->buf[4];
-	if (!net_mh_ok(p, mlen, mnum)) {
+	p->msg_num = iob->buf[0];
+	if (!net_mh_ok(p)) {
 	    btpd_log(BTPD_L_ERROR, "error in head\n");
 	    goto bad;
-	} else if (mlen == 1) {
-	    if (net_dispatch_msg(p, mlen, mnum, iob->buf) != 0) {
+	} else if (p->msg_len == 1) {
+	    if (net_dispatch_msg(p, iob->buf) != 0) {
 		btpd_log(BTPD_L_ERROR, "error in dispatch\n");
 		goto bad;
-		}
+	    }
 	    net_set_state(p, NET_MSGSIZE, 4);
-	    return 5;
 	} else {
-	    uint8_t nstate = mnum == MSG_PIECE ? NET_MSGPIECE : NET_MSGBODY;
-	    net_set_state(p, nstate, mlen + 4);
-	    return 0;
+	    uint8_t nstate =
+		p->msg_num == MSG_PIECE ? NET_MSGPIECE : NET_MSGBODY;
+	    net_set_state(p, nstate, p->msg_len - 1);
 	}
+	return 1;
     case NET_MSGPIECE:
     case NET_MSGBODY:
-	mlen = net_read32(iob->buf);
-	mnum = iob->buf[4];
-	assert(iob->buf_len >= mlen + 4);
-	if (net_dispatch_msg(p, mlen, mnum, iob->buf) != 0) {
+	if (net_dispatch_msg(p, iob->buf) != 0) {
 	    btpd_log(BTPD_L_ERROR, "error in dispatch\n");
 	    goto bad;
 	}
 	net_set_state(p, NET_MSGSIZE, 4);
-	return mlen + 4;
+	return p->msg_len - 1;
     default:
 	abort();
     }
@@ -314,23 +297,25 @@ bad:
 static unsigned long
 net_read(struct peer *p, unsigned long rmax)
 {
-    size_t baggage = p->net_in.buf_len;
-    char buf[GRBUFLEN + baggage];
-    struct io_buffer sbuf = { baggage, sizeof(buf), buf };
-    if (baggage > 0) {
-	bcopy(p->net_in.buf, buf, baggage);
-	free(p->net_in.buf);
-	p->net_in.buf = NULL;
-	p->net_in.buf_off = 0;
-	p->net_in.buf_len = 0;
-    }
+    size_t rest = p->net_in.buf_len - p->net_in.buf_off;
+    char buf[GRBUFLEN];
+    struct iovec iov[2] = {
+	{
+	    p->net_in.buf + p->net_in.buf_off,
+	    rest
+	}, {
+	    buf,
+	    sizeof(buf)
+	}
+    };
 
-    if (rmax > 0)
-	rmax = min(rmax, sbuf.buf_len - sbuf.buf_off);
-    else
-	rmax = sbuf.buf_len - sbuf.buf_off;
+    if (rmax > 0) {
+	if (iov[0].iov_len > rmax)
+	    iov[0].iov_len = rmax;
+	iov[1].iov_len = min(rmax - iov[0].iov_len, iov[1].iov_len);
+    }
 
-    ssize_t nread = read(p->sd, sbuf.buf + sbuf.buf_off, rmax);
+    ssize_t nread = readv(p->sd, iov, 2);
     if (nread < 0 && errno == EAGAIN)
 	goto out;
     else if (nread < 0) {
@@ -343,23 +328,36 @@ net_read(struct peer *p, unsigned long rmax)
 	return 0;
     }
 
-    sbuf.buf_len = sbuf.buf_off + nread;
-    sbuf.buf_off = 0;
-    while (p->state_bytes <= sbuf.buf_len) {
-	ssize_t chomped = net_state_foo(p, &sbuf);
-	if (chomped < 0)
+    if (rest > 0) {
+	if (nread < rest) {
+	    p->net_in.buf_off += nread;
+	    net_progress(p, nread);
+	    goto out;
+	}
+	net_progress(p, rest);
+	if (net_state(p, &p->net_in) < 0)
 	    return nread;
-	sbuf.buf += chomped;
-	sbuf.buf_len -= chomped;
-	baggage = 0;
+	free(p->net_in.buf);
+	bzero(&p->net_in, sizeof(p->net_in));
     }
 
-    net_progress(p, sbuf.buf_len - baggage);
+    struct io_buffer iob = { 0, nread - rest, buf };
+
+    while (p->state_bytes <= iob.buf_len) {
+	net_progress(p, p->state_bytes);
+	ssize_t chomped = net_state(p, &iob);
+	if (chomped < 0)
+	    return nread;
+	iob.buf += chomped;
+	iob.buf_len -= chomped;
+    }
 
-    if (sbuf.buf_len > 0) {
-	p->net_in = sbuf;
-	p->net_in.buf = btpd_malloc(sbuf.buf_len);
-	bcopy(sbuf.buf, p->net_in.buf, sbuf.buf_len);
+    if (iob.buf_len > 0) {
+	net_progress(p, iob.buf_len);
+	p->net_in.buf_off = iob.buf_len;
+	p->net_in.buf_len = p->state_bytes;
+	p->net_in.buf = btpd_malloc(p->state_bytes);
+	bcopy(iob.buf, p->net_in.buf, iob.buf_len);
     }
 
 out:
diff --git a/btpd/peer.h b/btpd/peer.h
index c92f925..b4e007f 100644
--- a/btpd/peer.h
+++ b/btpd/peer.h
@@ -8,9 +8,8 @@
 #define PF_ON_READQ	 0x10
 #define PF_ON_WRITEQ	 0x20
 #define PF_ATTACHED	 0x40
-#define PF_WRITE_CLOSE	 0x80	/* Close connection after writing all data */
-#define PF_NO_REQUESTS	0x100
-#define PF_INCOMING	0x200
+#define PF_NO_REQUESTS	 0x80
+#define PF_INCOMING	0x100
 
 #define RATEHISTORY 20
 #define MAXPIECEMSGS 128
@@ -52,6 +51,8 @@ struct peer {
 
     size_t state_bytes;
     uint8_t net_state;
+    uint8_t msg_num;
+    uint32_t msg_len;
     struct io_buffer net_in;
 
     BTPDQ_ENTRY(peer) cm_entry;