This repository has no description.
1const std = @import("std");
2const assert = std.debug.assert;
3const io = std.io;
4
pub const BytePacketBuffer = struct {
    buf: [512]u8 = undefined,
    pos: usize = 0,

    pub const ReadError = error{EndOfBuffer};

    pub const Reader = io.Reader(*BytePacketBuffer, ReadError, read);

    /// Move the buffer position forward by `pos` bytes.
    ///
    /// No bounds check is done here; once the position is at or past the end
    /// of the buffer, `read` reports end of stream by returning 0.
    pub fn step(self: *BytePacketBuffer, pos: usize) void {
        self.pos += pos;
    }

    /// Set the buffer position to the absolute offset `pos`.
    pub fn seek(self: *BytePacketBuffer, pos: usize) void {
        self.pos = pos;
    }

    /// Return a `std.io` reader that consumes bytes starting at `self.pos`.
    pub fn reader(self: *BytePacketBuffer) Reader {
        return .{ .context = self };
    }

    /// Copy up to `dest.len` bytes from the current position into `dest` and
    /// advance the position by the number of bytes copied.
    ///
    /// Returns the number of bytes copied; 0 means the position is at (or past)
    /// the end of the buffer, which `std.io` readers treat as end of stream.
    pub fn read(self: *BytePacketBuffer, dest: []u8) ReadError!usize {
        // `step`/`seek` may have moved `pos` past the end; guard before the
        // subtraction below so it cannot underflow.
        if (self.pos >= self.buf.len) return 0;

        const size = @min(dest.len, self.buf.len - self.pos);
        const end = self.pos + size;

        @memcpy(dest[0..size], self.buf[self.pos..end]);
        self.pos = end;

        return size;
    }

    /// Return the byte at absolute offset `pos` without changing the buffer position.
    ///
    /// Errors: `EndOfBuffer` if `pos` is outside the buffer.
    pub fn get(self: *const BytePacketBuffer, pos: usize) ReadError!u8 {
        if (pos >= self.buf.len) return ReadError.EndOfBuffer;
        return self.buf[pos];
    }

    /// Return `len` bytes starting at absolute offset `start` without changing
    /// the buffer position. The returned slice aliases the buffer's storage.
    ///
    /// Errors: `EndOfBuffer` if any part of the range lies outside the buffer.
    pub fn get_range(self: *const BytePacketBuffer, start: usize, len: usize) ReadError![]const u8 {
        // Phrased to avoid overflow in `start + len`. A range that ends exactly
        // at the end of the buffer is valid.
        if (start > self.buf.len or len > self.buf.len - start) return ReadError.EndOfBuffer;
        return self.buf[start .. start + len];
    }

    /// Read a (possibly compressed) domain name starting at the current position.
    ///
    /// Labels are length-prefixed: something like [3]www[6]google[3]com[0] is
    /// written to `outstr` as "www.google.com". A length byte with its two most
    /// significant bits set is a pointer: together with the following byte, its
    /// remaining 14 bits give an offset in the packet where the name continues.
    ///
    /// Afterwards the buffer position is just past the name as stored at the
    /// starting position: past the terminating zero byte, or past the first
    /// two-byte pointer if a jump was taken. Only the first bytes of `outstr`
    /// (the decoded name's length) are written; the rest is left untouched.
    ///
    /// Errors:
    ///   - `EndOfBuffer` if a length byte, pointer, or label runs past the buffer.
    ///   - `JumpLimitExceeded` if too many pointers are followed (guards against
    ///     pointer loops in malformed packets).
    ///   - `NoSpaceLeft` if `outstr` is too small for the decoded name.
    pub fn read_qname(self: *BytePacketBuffer, outstr: []u8) !void {
        // Jumps move us around the packet, so keep the read position local and
        // only commit it to `self.pos` once we know where the name ends.
        var pos = self.pos;
        var out_pos: usize = 0;

        // Track whether or not we've jumped.
        var jumped = false;
        const max_jumps: usize = 5;
        var jumps_performed: usize = 0;

        while (true) {
            if (jumps_performed > max_jumps) return error.JumpLimitExceeded;

            // Each label starts with a length byte.
            const len = try self.get(pos);

            if ((len & 0xC0) == 0xC0) {
                // Pointer: the low 6 bits of this byte plus the next byte form
                // the offset to continue reading from.
                const b2: u16 = try self.get(pos + 1);

                // The name as stored here ends with this 2-byte pointer, so the
                // caller continues right after it. Only the first jump counts.
                if (!jumped) self.seek(pos + 2);

                pos = ((@as(u16, len) ^ 0xC0) << 8) | b2;

                jumped = true;
                jumps_performed += 1;
                continue;
            }

            // Move past the length byte.
            pos += 1;

            // Domain names are terminated by an empty label of length 0.
            if (len == 0) break;

            // Label bytes are read from the local position, which is correct
            // both before and after a jump.
            const label = try self.get_range(pos, len);
            pos += len;

            // Labels are never empty here, so `out_pos > 0` means this is not
            // the first label and needs a '.' separator.
            const sep_len: usize = if (out_pos == 0) 0 else 1;
            const end = out_pos + sep_len + label.len;
            if (end > outstr.len) return error.NoSpaceLeft;

            if (sep_len != 0) outstr[out_pos] = '.';
            @memcpy(outstr[out_pos + sep_len .. end], label);
            out_pos = end;
        }

        // Without a jump, the name ends just past its terminating zero byte.
        if (!jumped) self.seek(pos);
    }
};
116
// A single big-endian byte read through the std.io reader adapter.
test "BytePacketBuffer.read" {
    var packet: BytePacketBuffer = .{};
    packet.buf[0] = 0x01;

    const value = try packet.reader().readInt(u8, .big);
    try std.testing.expectEqual(@as(u8, 0x01), value);
}
123
// Two bytes 0x01 0x01 decode as the big-endian u16 0x0101.
test "BytePacketBuffer.read_u16" {
    var packet: BytePacketBuffer = .{};
    packet.buf[0..2].* = .{ 0x01, 0x01 };

    const value = try packet.reader().readInt(u16, .big);
    try std.testing.expectEqual(@as(u16, 0x0101), value);
}
131
// Four 0x01 bytes decode as the big-endian u32 0x01010101.
test "BytePacketBuffer.read_u32" {
    var packet: BytePacketBuffer = .{};
    @memset(packet.buf[0..4], 0x01);

    const value = try packet.reader().readInt(u32, .big);
    try std.testing.expectEqual(@as(u32, 0x01010101), value);
}
141
// The label sequence [6]google[3]com[0] decodes to "google.com".
test "BytePacketBuffer.read_qname" {
    const encoded = [_]u8{ 0x06, 'g', 'o', 'o', 'g', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00 };
    const expected = "google.com";

    var packet: BytePacketBuffer = .{};
    @memcpy(packet.buf[0..encoded.len], &encoded);

    const gpa = std.testing.allocator;
    const decoded = try gpa.alloc(u8, expected.len);
    defer gpa.free(decoded);

    try packet.read_qname(decoded);
    try std.testing.expectEqualStrings(expected, decoded);
}