From fe68ae96d6ef049d34c3482c3fe0df970dc41cca Mon Sep 17 00:00:00 2001 From: alex Date: Wed, 8 Feb 2023 16:19:09 +0100 Subject: [PATCH] comm: format protocol header manually and eliminate a potential footgun the nd/ngui protocol is a simple TLV: message tag, length and a json-encoded payload. the issue is with the (tag, length) header which was written/read by using raw memory pointers, such as Writer.writeStruct. in addition, tag ordinal values were auto-generated by the compiler. all those deficiencies may potentially lead to bugs hard to find or serious data leaks. serializing and deserializing the (tag, length) header manually makes it more robust and eliminates the footgun. --- src/comm.zig | 88 ++++++++++++++++++++++++++++----------------------- src/types.zig | 4 +++ 2 files changed, 53 insertions(+), 39 deletions(-) diff --git a/src/comm.zig b/src/comm.zig index 70d88b1..6b4031f 100644 --- a/src/comm.zig +++ b/src/comm.zig @@ -1,8 +1,20 @@ -///! daemon <-> gui communication +///! daemon/gui communication. +///! the protocol is a simple TLV construct: MessageTag(u16), length(u64), json-marshalled Message; +///! little endian. const std = @import("std"); const json = std.json; const mem = std.mem; +const ByteArrayList = @import("types.zig").ByteArrayList; + +/// common errors returned by read/write functions. +pub const Error = error{ + CommReadInvalidTag, + CommReadZeroLenInNonVoidTag, + CommWriteTooLarge, +}; + +/// daemon and gui exchange messages of this type. pub const Message = union(MessageTag) { ping: void, pong: void, @@ -27,43 +39,42 @@ pub const Message = union(MessageTag) { }; }; -pub const MessageTag = enum(u8) { - ping, - pong, - poweroff, - wifi_connect, - network_report, - get_network_report, -}; - -const Header = extern struct { - tag: MessageTag, - len: usize, +/// it is important to preserve ordinal values for future compatiblity, +/// especially when nd and gui may temporary diverge in their implementations. +pub const MessageTag = enum(u16) { + ping = 0x01, + pong = 0x02, + poweroff = 0x03, + wifi_connect = 0x04, + network_report = 0x05, + get_network_report = 0x06, + // next: 0x07 }; /// reads and parses a single message from the input stream reader. /// callers must deallocate resources with free when done. -pub fn read(allocator: mem.Allocator, reader: anytype) anyerror!Message { - const h = try reader.readStruct(Header); - if (h.len == 0) { - const m = switch (h.tag) { +pub fn read(allocator: mem.Allocator, reader: anytype) !Message { + // alternative is @intToEnum(reader.ReadIntLittle(u16)) but it may panic. + const tag = reader.readEnum(MessageTag, .Little) catch { + return Error.CommReadInvalidTag; + }; + const len = try reader.readIntLittle(u64); + if (len == 0) { + return switch (tag) { .ping => Message{ .ping = {} }, .pong => Message{ .pong = {} }, .poweroff => Message{ .poweroff = {} }, - else => error.ZeroLenInNonVoidTag, + else => Error.CommReadZeroLenInNonVoidTag, }; - return m; } - // TODO: limit h.len to some max value - var bytes = try allocator.alloc(u8, h.len); + var bytes = try allocator.alloc(u8, len); defer allocator.free(bytes); try reader.readNoEof(bytes); - const jopt = json.ParseOptions{ .allocator = allocator, .ignore_unknown_fields = true }; var jstream = json.TokenStream.init(bytes); - return switch (h.tag) { - .ping, .pong, .poweroff => unreachable, // void + return switch (tag) { + .ping, .pong, .poweroff => unreachable, // handled above .wifi_connect => Message{ .wifi_connect = try json.parse(Message.WifiConnect, &jstream, jopt), }, @@ -79,30 +90,27 @@ pub fn read(allocator: mem.Allocator, reader: anytype) anyerror!Message { /// outputs the message msg using writer. /// all allocated resources are freed upon return. pub fn write(allocator: mem.Allocator, writer: anytype, msg: Message) !void { - var header = Header{ .tag = msg, .len = 0 }; - switch (msg) { - .ping, .pong, .poweroff => return writer.writeStruct(header), - else => {}, // non-zero payload; continue - } - - var data = std.ArrayList(u8).init(allocator); - defer data.deinit(); const jopt = .{ .whitespace = null }; + var data = ByteArrayList.init(allocator); + defer data.deinit(); switch (msg) { - .ping, .pong, .poweroff => unreachable, + .ping, .pong, .poweroff => {}, // zero length payload .wifi_connect => try json.stringify(msg.wifi_connect, jopt, data.writer()), .network_report => try json.stringify(msg.network_report, jopt, data.writer()), .get_network_report => try json.stringify(msg.get_network_report, jopt, data.writer()), } + if (data.items.len > std.math.maxInt(u64)) { + return Error.CommWriteTooLarge; + } - header.len = data.items.len; - try writer.writeStruct(header); + try writer.writeIntLittle(u16, @enumToInt(msg)); + try writer.writeIntLittle(u64, data.items.len); try writer.writeAll(data.items); } pub fn free(allocator: mem.Allocator, m: Message) void { switch (m) { - .ping, .pong, .poweroff => {}, + .ping, .pong, .poweroff => {}, // zero length payload else => |v| { json.parseFree(@TypeOf(v), v, .{ .allocator = allocator }); }, @@ -119,7 +127,8 @@ test "read" { var buf = std.ArrayList(u8).init(t.allocator); defer buf.deinit(); - try buf.writer().writeStruct(Header{ .tag = msg, .len = data.items.len }); + try buf.writer().writeIntLittle(u16, @enumToInt(msg)); + try buf.writer().writeIntLittle(u64, data.items.len); try buf.writer().writeAll(data.items); var bs = std.io.fixedBufferStream(buf.items); @@ -141,10 +150,11 @@ test "write" { const payload = "{\"ssid\":\"wlan\",\"password\":\"secret\"}"; var js = std.ArrayList(u8).init(t.allocator); defer js.deinit(); - try js.writer().writeStruct(Header{ .tag = msg, .len = payload.len }); + try js.writer().writeIntLittle(u16, @enumToInt(msg)); + try js.writer().writeIntLittle(u64, payload.len); try js.appendSlice(payload); - try t.expectEqualSlices(u8, js.items, buf.items); + try t.expectEqualStrings(js.items, buf.items); } test "write/read void tags" { diff --git a/src/types.zig b/src/types.zig index ad4de50..3330a83 100644 --- a/src/types.zig +++ b/src/types.zig @@ -1,6 +1,10 @@ const std = @import("std"); const builtin = @import("builtin"); +/// prefer this type over the std.ArrayList(u8) just to ensure consistency +/// and potential regressions. For example, comm module uses it for read/write. +pub const ByteArrayList = std.ArrayList(u8); + pub const Timer = if (builtin.is_test) TestTimer else std.time.Timer; /// TestTimer always reports the same fixed value.