this repo has no description
1const std = @import("std");
2const assert = std.debug.assert;
3const io = std.io;
4
/// A fixed 512-byte packet buffer with a read cursor (`pos`) marking where the
/// next sequential read starts.
pub const BytePacketBuffer = struct {
    buf: [512]u8 = undefined,
    pos: usize = 0,

    pub const ReadError = error{
        /// A read or lookup fell outside `buf` (or, for `readQname`, the
        /// caller's output slice was too small for the decoded name).
        EndOfBuffer,
        /// A name followed more compression pointers than `max_jumps` allows.
        JumpLimitExceeded,
    };

    /// Decoding a name fails once more than this many compression pointers
    /// have been followed; this bounds the work done on pointer loops.
    const max_jumps: usize = 5;

    /// Sequential `std.io` reader over the buffer, starting at `pos`.
    pub const Reader = io.Reader(*BytePacketBuffer, ReadError, read);

    /// Move the cursor forward by `pos` bytes. No bounds check is done here;
    /// later reads report `error.EndOfBuffer` if the cursor is past the end.
    pub fn step(self: *BytePacketBuffer, pos: usize) void {
        self.pos += pos;
    }

    /// Move the cursor to the absolute offset `pos`.
    pub fn seek(self: *BytePacketBuffer, pos: usize) void {
        self.pos = pos;
    }

    /// Return a reader that consumes bytes starting at the cursor.
    pub fn reader(self: *BytePacketBuffer) Reader {
        return .{ .context = self };
    }

    /// Copy `dest.len` bytes starting at the cursor into `dest` and advance the
    /// cursor past them. Returns `dest.len`. Fails with `error.EndOfBuffer`
    /// (copying nothing, cursor unchanged) if fewer than `dest.len` bytes remain.
    pub fn read(self: *BytePacketBuffer, dest: []u8) ReadError!usize {
        const src = try self.getRange(self.pos, dest.len);
        @memcpy(dest, src);
        self.pos += dest.len;
        return dest.len;
    }

    /// Return the byte at absolute offset `pos` without moving the cursor.
    pub fn get(self: *const BytePacketBuffer, pos: usize) ReadError!u8 {
        if (pos >= self.buf.len) return ReadError.EndOfBuffer;
        return self.buf[pos];
    }

    /// Return the `len` bytes starting at absolute offset `start` without
    /// moving the cursor. A range ending exactly at the end of `buf` is valid.
    pub fn getRange(self: *const BytePacketBuffer, start: usize, len: usize) ReadError![]const u8 {
        // Two comparisons instead of `start + len > buf.len` so nothing can overflow.
        if (start > self.buf.len or len > self.buf.len - start) return ReadError.EndOfBuffer;
        return self.buf[start..][0..len];
    }

    /// Decode the domain name at the cursor into `outstr`.
    ///
    /// Labels are joined with "." and written starting at `outstr[0]`, so
    /// [3]www[6]google[3]com[0] becomes "www.google.com". Bytes of `outstr`
    /// past the decoded name are left untouched (the decoded length is not
    /// returned). Length bytes with both high bits set are followed as
    /// pointers to another offset in the packet.
    ///
    /// On success the cursor ends just past the name as stored at the cursor:
    /// after the terminating zero byte, or after the first 2-byte pointer.
    /// Fails with `error.EndOfBuffer` if the name runs off the packet or does
    /// not fit in `outstr`, and with `error.JumpLimitExceeded` when too many
    /// pointers are followed. On failure the cursor is left unchanged.
    pub fn readQname(self: *BytePacketBuffer, outstr: []u8) ReadError!void {
        var sink: SliceSink = .{ .dest = outstr };
        try self.walkQname(&sink);
    }

    /// Decode the domain name at the cursor into a newly allocated string,
    /// e.g. [3]www[6]google[3]com[0] becomes "www.google.com".
    ///
    /// The caller owns the returned slice and must free it with `alloc`.
    /// Pointer handling, cursor movement and decoding errors are the same as
    /// for `readQname` (without an output-size limit), plus allocation failure.
    /// Nothing is leaked when decoding fails.
    pub fn readQnameAlloc(self: *BytePacketBuffer, alloc: std.mem.Allocator) ![]u8 {
        var buffer: std.ArrayList(u8) = .init(alloc);
        errdefer buffer.deinit();
        try self.walkQname(&buffer);
        return buffer.toOwnedSlice();
    }

    /// Shared decoder behind `readQname` and `readQnameAlloc`. It walks the
    /// name at the cursor and passes each label and each "." separator to
    /// `sink.appendSlice`. The cursor is only moved once the whole name has
    /// been decoded successfully.
    fn walkQname(self: *BytePacketBuffer, sink: anytype) !void {
        // Local read position; diverges from the cursor once a pointer is followed.
        var pos = self.pos;
        // Where the cursor resumes after the name; fixed by the first pointer.
        var resume_at: ?usize = null;
        var jumps_performed: usize = 0;
        var first_label = true;

        while (true) {
            if (jumps_performed > max_jumps) return ReadError.JumpLimitExceeded;

            // Each label starts with a length byte.
            const len = try self.get(pos);

            // Both high bits set: a pointer whose low 14 bits are an absolute
            // offset into the packet.
            if ((len & 0xC0) == 0xC0) {
                const low = try self.get(pos + 1);
                // Only the first pointer decides where the caller's cursor resumes.
                if (resume_at == null) resume_at = pos + 2;
                pos = (@as(usize, len & 0x3F) << 8) | low;
                jumps_performed += 1;
                continue;
            }

            pos += 1;
            // A zero-length label terminates the name.
            if (len == 0) break;

            if (!first_label) try sink.appendSlice(".");
            try sink.appendSlice(try self.getRange(pos, len));
            first_label = false;
            pos += len;
        }

        self.pos = resume_at orelse pos;
    }

    /// Fixed-capacity sink over a caller-provided slice, used by `readQname`.
    /// It reports `error.EndOfBuffer` instead of writing past the slice's end.
    const SliceSink = struct {
        dest: []u8,
        written: usize = 0,

        fn appendSlice(self: *SliceSink, bytes: []const u8) ReadError!void {
            if (bytes.len > self.dest.len - self.written) return ReadError.EndOfBuffer;
            @memcpy(self.dest[self.written..][0..bytes.len], bytes);
            self.written += bytes.len;
        }
    };
};
185
test "BytePacketBuffer.read" {
    // A one-byte big-endian read returns the first byte of the buffer.
    var packet: BytePacketBuffer = .{};
    packet.buf[0] = 0x01;
    const value = try packet.reader().readInt(u8, .big);
    try std.testing.expectEqual(0x1, value);
}
192
test "BytePacketBuffer.read_u16" {
    // Two 0x01 bytes decode to 0x0101 when read as a big-endian u16.
    var packet: BytePacketBuffer = .{};
    @memset(packet.buf[0..2], 0x01);
    const value = try packet.reader().readInt(u16, .big);
    try std.testing.expectEqual(0x101, value);
}
199}
200
test "BytePacketBuffer.read_u32" {
    // Four 0x01 bytes decode to 0x01010101 when read as a big-endian u32.
    var packet: BytePacketBuffer = .{};
    @memset(packet.buf[0..4], 0x01);
    const value = try packet.reader().readInt(u32, .big);
    try std.testing.expectEqual(0x1010101, value);
}
209}
210
test "BytePacketBuffer.read last byte" {
    // Reading the final byte succeeds; the next read runs off the end.
    var packet: BytePacketBuffer = .{};
    const last = packet.buf.len - 1;
    packet.buf[last] = 0x01;
    packet.pos = last;

    const stream = packet.reader();
    try std.testing.expectEqual(0x1, try stream.readInt(u8, .big));
    try std.testing.expectError(
        BytePacketBuffer.ReadError.EndOfBuffer,
        stream.readInt(u8, .big),
    );
}
221}
222
test "BytePacketBuffer.read_qname google.com" {
    const allocator = std.testing.allocator;

    // "google.com" in length-prefixed label form: [6]google[3]com[0].
    const encoded = [_]u8{ 0x06, 'g', 'o', 'o', 'g', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00 };
    const expected = "google.com";

    var packet: BytePacketBuffer = .{};
    @memcpy(packet.buf[0..encoded.len], &encoded);

    const name = try allocator.alloc(u8, expected.len);
    defer allocator.free(name);

    try packet.readQname(name);
    try std.testing.expectEqualStrings(expected, name);
}
254}
255
test "BytePacketBuffer.read_qname_alloc google.com" {
    const allocator = std.testing.allocator;

    // "google.com" in length-prefixed label form: [6]google[3]com[0].
    const encoded = [_]u8{ 0x06, 'g', 'o', 'o', 'g', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00 };
    const expected = "google.com";

    var packet: BytePacketBuffer = .{};
    @memcpy(packet.buf[0..encoded.len], &encoded);

    const name = try packet.readQnameAlloc(allocator);
    defer allocator.free(name);

    try std.testing.expectEqualStrings(expected, name);
}
285}
286