Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Merge branch 'virtio-vsock-some-updates-for-msg_peek-flag'

Arseniy Krasnov says:

====================
virtio/vsock: some updates for MSG_PEEK flag

This patchset does several things around MSG_PEEK flag support. In
general words it reworks MSG_PEEK test and adds support for this flag
in SOCK_SEQPACKET logic. Here is per-patch description:

1) This is cosmetic change for SOCK_STREAM implementation of MSG_PEEK:
1) I think there is no need of "safe" mode walk here as there is no
"unlink" of skbs inside loop (it is MSG_PEEK mode - we don't change
queue).
2) Nested while loop is removed: in case of MSG_PEEK we just walk
over skbs and copy data from each one. I guess this nested loop
even didn't behave as loop - it always executed just for single
iteration.

2) This adds MSG_PEEK support for SOCK_SEQPACKET. It could be implemented
be reworking MSG_PEEK callback for SOCK_STREAM to support SOCK_SEQPACKET
also, but I think it will be more simple and clear from potential
bugs to implemented it as separate function thus not mixing logics
for both types of socket. So I've added it as dedicated function.

3) This is reworked MSG_PEEK test for SOCK_STREAM. Previous version just
sent single byte, then tried to read it with MSG_PEEK flag, then read
it in normal way. New version is more complex: now sender uses buffer
instead of single byte and this buffer is initialized with random
values. Receiver tests several things:
1) Read empty socket with MSG_PEEK flag.
2) Read part of buffer with MSG_PEEK flag.
3) Read whole buffer with MSG_PEEK flag, then checks that it is same
as buffer from 2) (limited by size of buffer from 2) of course).
4) Read whole buffer without any flags, then checks that it is same
as buffer from 3).

4) This is MSG_PEEK test for SOCK_SEQPACKET. It works in the same way
as for SOCK_STREAM, except it also checks combination of MSG_TRUNC
and MSG_PEEK.
====================

Link: https://lore.kernel.org/r/20230725172912.1659970-1-AVKrasnov@sberdevices.ru
Signed-off-by: Paolo Abeni <pabeni@redhat.com>

+213 -37
+82 -28
net/vmw_vsock/virtio_transport_common.c
··· 348 348 size_t len) 349 349 { 350 350 struct virtio_vsock_sock *vvs = vsk->trans; 351 - size_t bytes, total = 0, off; 352 - struct sk_buff *skb, *tmp; 353 - int err = -EFAULT; 351 + struct sk_buff *skb; 352 + size_t total = 0; 353 + int err; 354 354 355 355 spin_lock_bh(&vvs->rx_lock); 356 356 357 - skb_queue_walk_safe(&vvs->rx_queue, skb, tmp) { 358 - off = 0; 357 + skb_queue_walk(&vvs->rx_queue, skb) { 358 + size_t bytes; 359 + 360 + bytes = len - total; 361 + if (bytes > skb->len) 362 + bytes = skb->len; 363 + 364 + spin_unlock_bh(&vvs->rx_lock); 365 + 366 + /* sk_lock is held by caller so no one else can dequeue. 367 + * Unlock rx_lock since memcpy_to_msg() may sleep. 368 + */ 369 + err = memcpy_to_msg(msg, skb->data, bytes); 370 + if (err) 371 + goto out; 372 + 373 + total += bytes; 374 + 375 + spin_lock_bh(&vvs->rx_lock); 359 376 360 377 if (total == len) 361 378 break; 362 - 363 - while (total < len && off < skb->len) { 364 - bytes = len - total; 365 - if (bytes > skb->len - off) 366 - bytes = skb->len - off; 367 - 368 - /* sk_lock is held by caller so no one else can dequeue. 369 - * Unlock rx_lock since memcpy_to_msg() may sleep. 370 - */ 371 - spin_unlock_bh(&vvs->rx_lock); 372 - 373 - err = memcpy_to_msg(msg, skb->data + off, bytes); 374 - if (err) 375 - goto out; 376 - 377 - spin_lock_bh(&vvs->rx_lock); 378 - 379 - total += bytes; 380 - off += bytes; 381 - } 382 379 } 383 380 384 381 spin_unlock_bh(&vvs->rx_lock); ··· 458 461 if (total) 459 462 err = total; 460 463 return err; 464 + } 465 + 466 + static ssize_t 467 + virtio_transport_seqpacket_do_peek(struct vsock_sock *vsk, 468 + struct msghdr *msg) 469 + { 470 + struct virtio_vsock_sock *vvs = vsk->trans; 471 + struct sk_buff *skb; 472 + size_t total, len; 473 + 474 + spin_lock_bh(&vvs->rx_lock); 475 + 476 + if (!vvs->msg_count) { 477 + spin_unlock_bh(&vvs->rx_lock); 478 + return 0; 479 + } 480 + 481 + total = 0; 482 + len = msg_data_left(msg); 483 + 484 + skb_queue_walk(&vvs->rx_queue, skb) { 485 + struct virtio_vsock_hdr *hdr; 486 + 487 + if (total < len) { 488 + size_t bytes; 489 + int err; 490 + 491 + bytes = len - total; 492 + if (bytes > skb->len) 493 + bytes = skb->len; 494 + 495 + spin_unlock_bh(&vvs->rx_lock); 496 + 497 + /* sk_lock is held by caller so no one else can dequeue. 498 + * Unlock rx_lock since memcpy_to_msg() may sleep. 499 + */ 500 + err = memcpy_to_msg(msg, skb->data, bytes); 501 + if (err) 502 + return err; 503 + 504 + spin_lock_bh(&vvs->rx_lock); 505 + } 506 + 507 + total += skb->len; 508 + hdr = virtio_vsock_hdr(skb); 509 + 510 + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) { 511 + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR) 512 + msg->msg_flags |= MSG_EOR; 513 + 514 + break; 515 + } 516 + } 517 + 518 + spin_unlock_bh(&vvs->rx_lock); 519 + 520 + return total; 461 521 } 462 522 463 523 static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk, ··· 611 557 int flags) 612 558 { 613 559 if (flags & MSG_PEEK) 614 - return -EOPNOTSUPP; 615 - 616 - return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags); 560 + return virtio_transport_seqpacket_do_peek(vsk, msg); 561 + else 562 + return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags); 617 563 } 618 564 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue); 619 565
+131 -9
tools/testing/vsock/vsock_test.c
··· 255 255 close(fds[i]); 256 256 } 257 257 258 - static void test_stream_msg_peek_client(const struct test_opts *opts) 259 - { 260 - int fd; 258 + #define MSG_PEEK_BUF_LEN 64 261 259 262 - fd = vsock_stream_connect(opts->peer_cid, 1234); 260 + static void test_msg_peek_client(const struct test_opts *opts, 261 + bool seqpacket) 262 + { 263 + unsigned char buf[MSG_PEEK_BUF_LEN]; 264 + ssize_t send_size; 265 + int fd; 266 + int i; 267 + 268 + if (seqpacket) 269 + fd = vsock_seqpacket_connect(opts->peer_cid, 1234); 270 + else 271 + fd = vsock_stream_connect(opts->peer_cid, 1234); 272 + 263 273 if (fd < 0) { 264 274 perror("connect"); 265 275 exit(EXIT_FAILURE); 266 276 } 267 277 268 - send_byte(fd, 1, 0); 278 + for (i = 0; i < sizeof(buf); i++) 279 + buf[i] = rand() & 0xFF; 280 + 281 + control_expectln("SRVREADY"); 282 + 283 + send_size = send(fd, buf, sizeof(buf), 0); 284 + 285 + if (send_size < 0) { 286 + perror("send"); 287 + exit(EXIT_FAILURE); 288 + } 289 + 290 + if (send_size != sizeof(buf)) { 291 + fprintf(stderr, "Invalid send size %zi\n", send_size); 292 + exit(EXIT_FAILURE); 293 + } 294 + 269 295 close(fd); 270 296 } 271 297 272 - static void test_stream_msg_peek_server(const struct test_opts *opts) 298 + static void test_msg_peek_server(const struct test_opts *opts, 299 + bool seqpacket) 273 300 { 301 + unsigned char buf_half[MSG_PEEK_BUF_LEN / 2]; 302 + unsigned char buf_normal[MSG_PEEK_BUF_LEN]; 303 + unsigned char buf_peek[MSG_PEEK_BUF_LEN]; 304 + ssize_t res; 274 305 int fd; 275 306 276 - fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL); 307 + if (seqpacket) 308 + fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL); 309 + else 310 + fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL); 311 + 277 312 if (fd < 0) { 278 313 perror("accept"); 279 314 exit(EXIT_FAILURE); 280 315 } 281 316 282 - recv_byte(fd, 1, MSG_PEEK); 283 - recv_byte(fd, 1, 0); 317 + /* Peek from empty socket. */ 318 + res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK | MSG_DONTWAIT); 319 + if (res != -1) { 320 + fprintf(stderr, "expected recv(2) failure, got %zi\n", res); 321 + exit(EXIT_FAILURE); 322 + } 323 + 324 + if (errno != EAGAIN) { 325 + perror("EAGAIN expected"); 326 + exit(EXIT_FAILURE); 327 + } 328 + 329 + control_writeln("SRVREADY"); 330 + 331 + /* Peek part of data. */ 332 + res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK); 333 + if (res != sizeof(buf_half)) { 334 + fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n", 335 + sizeof(buf_half), res); 336 + exit(EXIT_FAILURE); 337 + } 338 + 339 + /* Peek whole data. */ 340 + res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK); 341 + if (res != sizeof(buf_peek)) { 342 + fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n", 343 + sizeof(buf_peek), res); 344 + exit(EXIT_FAILURE); 345 + } 346 + 347 + /* Compare partial and full peek. */ 348 + if (memcmp(buf_half, buf_peek, sizeof(buf_half))) { 349 + fprintf(stderr, "Partial peek data mismatch\n"); 350 + exit(EXIT_FAILURE); 351 + } 352 + 353 + if (seqpacket) { 354 + /* This type of socket supports MSG_TRUNC flag, 355 + * so check it with MSG_PEEK. We must get length 356 + * of the message. 357 + */ 358 + res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK | 359 + MSG_TRUNC); 360 + if (res != sizeof(buf_peek)) { 361 + fprintf(stderr, 362 + "recv(2) + MSG_PEEK | MSG_TRUNC, exp %zu, got %zi\n", 363 + sizeof(buf_half), res); 364 + exit(EXIT_FAILURE); 365 + } 366 + } 367 + 368 + res = recv(fd, buf_normal, sizeof(buf_normal), 0); 369 + if (res != sizeof(buf_normal)) { 370 + fprintf(stderr, "recv(2), expected %zu, got %zi\n", 371 + sizeof(buf_normal), res); 372 + exit(EXIT_FAILURE); 373 + } 374 + 375 + /* Compare full peek and normal read. */ 376 + if (memcmp(buf_peek, buf_normal, sizeof(buf_peek))) { 377 + fprintf(stderr, "Full peek data mismatch\n"); 378 + exit(EXIT_FAILURE); 379 + } 380 + 284 381 close(fd); 382 + } 383 + 384 + static void test_stream_msg_peek_client(const struct test_opts *opts) 385 + { 386 + return test_msg_peek_client(opts, false); 387 + } 388 + 389 + static void test_stream_msg_peek_server(const struct test_opts *opts) 390 + { 391 + return test_msg_peek_server(opts, false); 285 392 } 286 393 287 394 #define SOCK_BUF_SIZE (2 * 1024 * 1024) ··· 1160 1053 close(fd); 1161 1054 } 1162 1055 1056 + static void test_seqpacket_msg_peek_client(const struct test_opts *opts) 1057 + { 1058 + return test_msg_peek_client(opts, true); 1059 + } 1060 + 1061 + static void test_seqpacket_msg_peek_server(const struct test_opts *opts) 1062 + { 1063 + return test_msg_peek_server(opts, true); 1064 + } 1065 + 1163 1066 static struct test_case test_cases[] = { 1164 1067 { 1165 1068 .name = "SOCK_STREAM connection reset", ··· 1244 1127 .name = "SOCK_STREAM virtio skb merge", 1245 1128 .run_client = test_stream_virtio_skb_merge_client, 1246 1129 .run_server = test_stream_virtio_skb_merge_server, 1130 + }, 1131 + { 1132 + .name = "SOCK_SEQPACKET MSG_PEEK", 1133 + .run_client = test_seqpacket_msg_peek_client, 1134 + .run_server = test_seqpacket_msg_peek_server, 1247 1135 }, 1248 1136 {}, 1249 1137 };