this repo has no description
1const std = @import("std");
2const assert = std.debug.assert;
3const io = std.io;
4
pub const BytePacketBuffer = struct {
    buf: [512]u8 = undefined,
    pos: usize = 0,

    pub const ReadError = error{
        EndOfBuffer,
        JumpLimitExceeded,
    };

    pub const Reader = io.Reader(*BytePacketBuffer, ReadError, read);

    /// Move the buffer position forward by `pos` bytes.
    /// No bounds check is done here; later reads report `error.EndOfBuffer`.
    pub fn step(self: *BytePacketBuffer, pos: usize) void {
        self.pos += pos;
    }

    /// Set the buffer position to the absolute offset `pos`.
    pub fn seek(self: *BytePacketBuffer, pos: usize) void {
        self.pos = pos;
    }

    /// Return a `Reader` that consumes bytes from the current position.
    pub fn reader(self: *BytePacketBuffer) Reader {
        return .{ .context = self };
    }

    /// Copy exactly `dest.len` bytes from the current position into `dest` and
    /// advance the position past them. Returns `dest.len`.
    ///
    /// Never performs a short read: if fewer than `dest.len` bytes remain it
    /// returns `error.EndOfBuffer` and leaves the position unchanged.
    pub fn read(self: *BytePacketBuffer, dest: []u8) ReadError!usize {
        const src = try self.get_range(self.pos, dest.len);
        @memcpy(dest, src);
        self.pos += dest.len;
        return dest.len;
    }

    /// Get the byte at absolute offset `pos` without changing the buffer position.
    pub fn get(self: *const BytePacketBuffer, pos: usize) ReadError!u8 {
        if (pos >= self.buf.len) return ReadError.EndOfBuffer;
        return self.buf[pos];
    }

    /// Get `len` bytes starting at absolute offset `start` without changing the
    /// buffer position. The returned slice aliases the internal buffer.
    ///
    /// A range ending exactly at the end of the buffer is valid.
    pub fn get_range(self: *const BytePacketBuffer, start: usize, len: usize) ReadError![]const u8 {
        // Two comparisons instead of `start + len > buf.len` so the sum cannot overflow.
        if (start > self.buf.len or len > self.buf.len - start) return ReadError.EndOfBuffer;
        return self.buf[start .. start + len];
    }

    /// Read a (possibly compressed) domain name starting at the current position.
    ///
    /// Labels such as [3]www[6]google[3]com[0] are written to `outstr` joined
    /// with '.', i.e. "www.google.com". A length byte with both of its two most
    /// significant bits set is a pointer: its low 6 bits plus the next byte form
    /// a 14-bit offset where the rest of the name continues. Names that need more
    /// than five pointers are rejected with `error.JumpLimitExceeded`.
    ///
    /// On success the position is left just past the name as it appears in the
    /// stream: after the terminating zero byte, or after the two bytes of the
    /// first pointer if one was followed.
    ///
    /// `outstr` must be large enough for the decoded name; bytes after the name
    /// are left untouched. A too-small `outstr` is a safety-checked panic
    /// (illegal behavior in ReleaseFast).
    ///
    /// Returns `error.EndOfBuffer` if a label or pointer runs past the buffer.
    pub fn read_qname(self: *BytePacketBuffer, outstr: []u8) ReadError!void {
        // Jumps move us around the packet, so walk it with a local cursor and only
        // commit to `self.pos` once we know where the name ends in the stream.
        var pos = self.pos;
        var out_pos: usize = 0;

        // Set once the first pointer is followed; from then on `self.pos` is final.
        var jumped = false;
        // Caps pointer chasing so a packet whose pointers form a cycle terminates.
        const max_jumps: usize = 5;
        var jumps_performed: usize = 0;

        while (true) {
            if (jumps_performed > max_jumps) return ReadError.JumpLimitExceeded;

            // Each label starts with a length byte.
            const len = try self.get(pos);

            if ((len & 0xC0) == 0xC0) {
                const b2 = @as(u16, try self.get(pos + 1));
                // The caller resumes right after the first pointer's two bytes;
                // later pointers live elsewhere in the packet and must not move it.
                if (!jumped) self.seek(pos + 2);

                pos = (@as(u16, len & 0x3F) << 8) | b2;
                jumped = true;
                jumps_performed += 1;
                continue;
            }

            // Step past the length byte.
            pos += 1;

            // Domain names are terminated by an empty label of length 0.
            if (len == 0) break;

            // Read the label through the local cursor, never through `self.pos`,
            // so labels reached via a pointer come from the pointed-to location.
            const label = try self.get_range(pos, len);

            // Separate labels with '.'; nothing precedes the first label.
            if (out_pos > 0) {
                outstr[out_pos] = '.';
                out_pos += 1;
            }
            @memcpy(outstr[out_pos .. out_pos + len], label);
            out_pos += len;
            pos += len;
        }

        // Without a jump the name was stored in-line; resume just past the zero byte.
        if (!jumped) self.seek(pos);
    }
};

test "BytePacketBuffer.read_qname leaves the position after the terminating zero" {
    var buf = BytePacketBuffer{};
    const packet = [_]u8{ 0x06, 'g', 'o', 'o', 'g', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00 };
    @memcpy(buf.buf[0..packet.len], &packet);

    var out: [10]u8 = undefined;
    try buf.read_qname(&out);

    try std.testing.expectEqualStrings("google.com", &out);
    try std.testing.expectEqual(@as(usize, packet.len), buf.pos);
}

test "BytePacketBuffer.read_qname follows compression pointers" {
    var buf = BytePacketBuffer{};
    // Offset 0: [6]google[3]com[0]; offset 12: [3]www followed by a pointer to 0.
    const packet = [_]u8{
        0x06, 'g', 'o', 'o', 'g', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00,
        0x03, 'w', 'w', 'w', 0xC0, 0x00,
    };
    @memcpy(buf.buf[0..packet.len], &packet);
    buf.seek(12);

    var out: [14]u8 = undefined;
    try buf.read_qname(&out);

    try std.testing.expectEqualStrings("www.google.com", &out);
    // Position resumes right after the two pointer bytes.
    try std.testing.expectEqual(@as(usize, 18), buf.pos);
}

test "BytePacketBuffer.read_qname rejects pointer cycles" {
    var buf = BytePacketBuffer{};
    // A pointer to offset 0 that points at itself.
    buf.buf[0] = 0xC0;
    buf.buf[1] = 0x00;

    var out: [16]u8 = undefined;
    try std.testing.expectError(
        BytePacketBuffer.ReadError.JumpLimitExceeded,
        buf.read_qname(&out),
    );
}

test "BytePacketBuffer.get_range accepts a range ending at the buffer end" {
    var buf = BytePacketBuffer{};
    const n = buf.buf.len;
    buf.buf[n - 2] = 0xAB;
    buf.buf[n - 1] = 0xCD;
    try std.testing.expectEqualSlices(u8, &.{ 0xAB, 0xCD }, try buf.get_range(n - 2, 2));
    try std.testing.expectError(BytePacketBuffer.ReadError.EndOfBuffer, buf.get_range(n - 1, 2));
}
119
test "BytePacketBuffer.read" {
    // A single big-endian byte at the start of the buffer reads back unchanged.
    var packet: BytePacketBuffer = .{};
    packet.buf[0] = 0x01;

    const value = try packet.reader().readInt(u8, .big);
    try std.testing.expectEqual(@as(u8, 0x01), value);
}
126
test "BytePacketBuffer.read_u16" {
    // Two 0x01 bytes decode big-endian as 0x0101.
    var packet: BytePacketBuffer = .{};
    @memset(packet.buf[0..2], 0x01);

    const value = try packet.reader().readInt(u16, .big);
    try std.testing.expectEqual(@as(u16, 0x0101), value);
}
133}
134
test "BytePacketBuffer.read_u32" {
    // Four 0x01 bytes decode big-endian as 0x01010101.
    var packet: BytePacketBuffer = .{};
    @memset(packet.buf[0..4], 0x01);

    const value = try packet.reader().readInt(u32, .big);
    try std.testing.expectEqual(@as(u32, 0x1010101), value);
}
143}
144
145test "BytePacketBuffer.read last byte" {
146 const testing = std.testing;
147 var buf = BytePacketBuffer{};
148 buf.buf[buf.buf.len - 1] = 0x1;
149 buf.pos = buf.buf.len - 1;
150 try testing.expectEqual(0x1, try buf.reader().readInt(u8, .big));
151 try testing.expectError(
152 BytePacketBuffer.ReadError.EndOfBuffer,
153 buf.reader().readInt(u8, .big),
154 );
155}
156
157test "BytePacketBuffer.read_qname google.com" {
158 const testing = std.testing;
159 const allocator = testing.allocator;
160
161 const input = [_]u8{
162 0x06, // [6]
163 0x67, // g
164 0x6f, // o
165 0x6f, // o
166 0x67, // g
167 0x6c, // l
168 0x65, // e
169 0x03, // [3]
170 0x63, // c
171 0x6f, // o
172 0x6d, // m
173 0x00, // [0]
174 };
175 const expected = "google.com";
176
177 var buf = BytePacketBuffer{};
178 for (input, 0..) |char, idx| {
179 buf.buf[idx] = char;
180 }
181
182 const outstr = try allocator.alloc(u8, expected.len);
183 defer allocator.free(outstr);
184
185 try buf.read_qname(outstr);
186
187 try testing.expectEqualStrings(expected, outstr);
188}