From 3eafcab64ecaf8d00a9264b441e996825a6a31bd Mon Sep 17 00:00:00 2001 From: Lars Müller <34514239+appgurueu@users.noreply.github.com> Date: Sat, 11 Jun 2022 20:00:26 +0200 Subject: Builtin: Redo serialize.lua (#11427) Features: * Support for arbitrary references, including self-referencing * Short output, references "long" strings as a bonus * Around the same speed, potentially slower if long, short keys are present * Properly works with NaN and inf --- builtin/common/serialize.lua | 349 ++++++++++++++++---------------- builtin/common/tests/serialize_spec.lua | 155 ++++++++++++-- 2 files changed, 315 insertions(+), 189 deletions(-) (limited to 'builtin/common') diff --git a/builtin/common/serialize.lua b/builtin/common/serialize.lua index 300b394c6..6278e2739 100644 --- a/builtin/common/serialize.lua +++ b/builtin/common/serialize.lua @@ -1,205 +1,214 @@ --- Lua module to serialize values as Lua code. --- From: https://github.com/fab13n/metalua/blob/no-dll/src/lib/serialize.lua +-- From: https://github.com/appgurueu/modlib/blob/master/luon.lua -- License: MIT --- @copyright 2006-2997 Fabien Fleutot --- @author Fabien Fleutot --- @author ShadowNinja --------------------------------------------------------------------------------- ---- Serialize an object into a source code string. This string, when passed as --- an argument to deserialize(), returns an object structurally identical to --- the original one. The following are currently supported: --- * Booleans, numbers, strings, and nil. --- * Functions; uses interpreter-dependent (and sometimes platform-dependent) bytecode! --- * Tables; they can cantain multiple references and can be recursive, but metatables aren't saved. --- This works in two phases: --- 1. Recursively find and record multiple references and recursion. --- 2. Recursively dump the value into a string. --- @param x Value to serialize (nil is allowed). --- @return load()able string containing the value. -function core.serialize(x) - local local_index = 1 -- Top index of the "_" local table in the dump - -- table->nil/1/2 set of tables seen. - -- nil = not seen, 1 = seen once, 2 = seen multiple times. - local seen = {} +local next, rawget, pairs, pcall, error, type, setfenv, loadstring + = next, rawget, pairs, pcall, error, type, setfenv, loadstring - -- nest_points are places where a table appears within itself, directly - -- or not. For instance, all of these chunks create nest points in - -- table x: "x = {}; x[x] = 1", "x = {}; x[1] = x", - -- "x = {}; x[1] = {y = {x}}". - -- To handle those, two tables are used by mark_nest_point: - -- * nested - Transient set of tables being currently traversed. - -- Used for detecting nested tables. - -- * nest_points - parent->{key=value, ...} table cantaining the nested - -- keys and values in the parent. They're all dumped after all the - -- other table operations have been performed. - -- - -- mark_nest_point(p, k, v) fills nest_points with information required - -- to remember that key/value (k, v) creates a nest point in table - -- parent. It also marks "parent" and the nested item(s) as occuring - -- multiple times, since several references to it will be required in - -- order to patch the nest points. - local nest_points = {} - local nested = {} - local function mark_nest_point(parent, k, v) - local nk, nv = nested[k], nested[v] - local np = nest_points[parent] - if not np then - np = {} - nest_points[parent] = np - end - np[k] = v - seen[parent] = 2 - if nk then seen[k] = 2 end - if nv then seen[v] = 2 end - end +local table_concat, string_dump, string_format, string_match, math_huge + = table.concat, string.dump, string.format, string.match, math.huge - -- First phase, list the tables and functions which appear more than - -- once in x. - local function mark_multiple_occurences(x) - local tp = type(x) - if tp ~= "table" and tp ~= "function" then - -- No identity (comparison is done by value, not by instance) +-- Recursively counts occurences of objects (non-primitives including strings) in a table. +local function count_objects(value) + local counts = {} + if value == nil then + -- Early return for nil; tables can't contain nil + return counts + end + local function count_values(val) + local type_ = type(val) + if type_ == "boolean" or type_ == "number" then return end - if seen[x] == 1 then - seen[x] = 2 - elseif seen[x] ~= 2 then - seen[x] = 1 - end - - if tp == "table" then - nested[x] = true - for k, v in pairs(x) do - if nested[k] or nested[v] then - mark_nest_point(x, k, v) - else - mark_multiple_occurences(k) - mark_multiple_occurences(v) + local count = counts[val] + counts[val] = (count or 0) + 1 + if type_ == "table" then + if not count then + for k, v in pairs(val) do + count_values(k) + count_values(v) end end - nested[x] = nil + elseif type_ ~= "string" and type_ ~= "function" then + error("unsupported type: " .. type_) end end + count_values(value) + return counts +end - local dumped = {} -- object->varname set - local local_defs = {} -- Dumped local definitions as source code lines +-- Build a "set" of Lua keywords. These can't be used as short key names. +-- See https://www.lua.org/manual/5.1/manual.html#2.1 +local keywords = {} +for _, keyword in pairs({ + "and", "break", "do", "else", "elseif", + "end", "false", "for", "function", "if", + "in", "local", "nil", "not", "or", + "repeat", "return", "then", "true", "until", "while", + "goto" -- LuaJIT, Lua 5.2+ +}) do + keywords[keyword] = true +end - -- Mutually recursive local functions: - local dump_val, dump_or_ref_val +local function quote(string) + return string_format("%q", string) +end - -- If x occurs multiple times, dump the local variable rather than - -- the value. If it's the first time it's dumped, also dump the - -- content in local_defs. - function dump_or_ref_val(x) - if seen[x] ~= 2 then - return dump_val(x) - end - local var = dumped[x] - if var then -- Already referenced - return var +local function dump_func(func) + return string_format("loadstring(%q)", string_dump(func)) +end + +-- Serializes Lua nil, booleans, numbers, strings, tables and even functions +-- Tables are referenced by reference, strings are referenced by value. Supports circular tables. +local function serialize(value, write) + local reference, refnum = "r1", 1 + -- [object] = reference string + local references = {} + -- Circular tables that must be filled using `table[key] = value` statements + local to_fill = {} + for object, count in pairs(count_objects(value)) do + local type_ = type(object) + -- Object must appear more than once. If it is a string, the reference has to be shorter than the string. + if count >= 2 and (type_ ~= "string" or #reference + 2 < #object) then + write(reference) + write("=") + if type_ == "table" then + write("{}") + elseif type_ == "function" then + write(dump_func(object)) + elseif type_ == "string" then + write(quote(object)) + end + write(";") + references[object] = reference + if type_ == "table" then + to_fill[object] = reference + end + refnum = refnum + 1 + reference = ("r%X"):format(refnum) end - -- First occurence, create and register reference - local val = dump_val(x) - local i = local_index - local_index = local_index + 1 - var = "_["..i.."]" - local_defs[#local_defs + 1] = var.." = "..val - dumped[x] = var - return var end - - -- Second phase. Dump the object; subparts occuring multiple times - -- are dumped in local variables which can be referenced multiple - -- times. Care is taken to dump local vars in a sensible order. - function dump_val(x) - local tp = type(x) - if x == nil then return "nil" - elseif tp == "string" then return string.format("%q", x) - elseif tp == "boolean" then return x and "true" or "false" - elseif tp == "function" then - return string.format("loadstring(%q)", string.dump(x)) - elseif tp == "number" then - -- Serialize numbers reversibly with string.format - return string.format("%.17g", x) - elseif tp == "table" then - local vals = {} - local idx_dumped = {} - local np = nest_points[x] - for i, v in ipairs(x) do - if not np or not np[i] then - vals[#vals + 1] = dump_or_ref_val(v) - end - idx_dumped[i] = true + -- Used to decide whether we should do "key=..." + local function use_short_key(key) + return not references[key] and type(key) == "string" and (not keywords[key]) and string_match(key, "^[%a_][%a%d_]*$") + end + local function dump(value) + -- Primitive types + if value == nil then + return write("nil") + end + if value == true then + return write("true") + end + if value == false then + return write("false") + end + local type_ = type(value) + if type_ == "number" then + return write(string_format("%.17g", value)) + end + -- Reference types: table, function and string + local ref = references[value] + if ref then + return write(ref) + end + if type_ == "string" then + return write(quote(value)) + end + if type_ == "function" then + return write(dump_func(value)) + end + if type_ == "table" then + write("{") + -- First write list keys: + -- Don't use the table length #value here as it may horribly fail + -- for tables which use large integers as keys in the hash part; + -- stop at the first "hole" (nil value) instead + local len = 0 + local first = true -- whether this is the first entry, which may not have a leading comma + while true do + local v = rawget(value, len + 1) -- use rawget to avoid metatables like the vector metatable + if v == nil then break end + if first then first = false else write(",") end + dump(v) + len = len + 1 end - for k, v in pairs(x) do - if (not np or not np[k]) and - not idx_dumped[k] then - vals[#vals + 1] = "["..dump_or_ref_val(k).."] = " - ..dump_or_ref_val(v) + -- Now write map keys ([key] = value) + for k, v in next, value do + -- We have written all non-float keys in [1, len] already + if type(k) ~= "number" or k % 1 ~= 0 or k < 1 or k > len then + if first then first = false else write(",") end + if use_short_key(k) then + write(k) + else + write("[") + dump(k) + write("]") + end + write("=") + dump(v) end end - return "{"..table.concat(vals, ", ").."}" - else - error("Can't serialize data of type "..tp) + write("}") + return end end - - local function dump_nest_points() - for parent, vals in pairs(nest_points) do - for k, v in pairs(vals) do - local_defs[#local_defs + 1] = dump_or_ref_val(parent) - .."["..dump_or_ref_val(k).."] = " - ..dump_or_ref_val(v) + -- Write the statements to fill circular tables + for table, ref in pairs(to_fill) do + for k, v in pairs(table) do + write(ref) + if use_short_key(k) then + write(".") + write(k) + else + write("[") + dump(k) + write("]") end + write("=") + dump(v) + write(";") end end - - mark_multiple_occurences(x) - local top_level = dump_or_ref_val(x) - dump_nest_points() - - if next(local_defs) then - return "local _ = {}\n" - ..table.concat(local_defs, "\n") - .."\nreturn "..top_level - else - return "return "..top_level - end + write("return ") + dump(value) end --- Deserialization - -local function safe_loadstring(...) - local func, err = loadstring(...) - if func then - setfenv(func, {}) - return func - end - return nil, err +function core.serialize(value) + local rope = {} + serialize(value, function(text) + -- Faster than table.insert(rope, text) on PUC Lua 5.1 + rope[#rope + 1] = text + end) + return table_concat(rope) end local function dummy_func() end -function core.deserialize(str, safe) - if type(str) ~= "string" then - return nil, "Cannot deserialize type '"..type(str) - .."'. Argument must be a string." - end - if str:byte(1) == 0x1B then - return nil, "Bytecode prohibited" - end - local f, err = loadstring(str) - if not f then return nil, err end +local nan = (0/0)^1 -- +nan - -- The environment is recreated every time so deseralized code cannot - -- pollute it with permanent references. - setfenv(f, {loadstring = safe and dummy_func or safe_loadstring}) +function core.deserialize(str, safe) + local func, err = loadstring(str) + if not func then return nil, err end - local good, data = pcall(f) - if good then - return data + -- math.huge is serialized to inf, NaNs are serialized to nan by Lua + local env = {inf = math_huge, nan = nan} + if safe then + env.loadstring = dummy_func else - return nil, data + env.loadstring = function(str, ...) + local func, err = loadstring(str, ...) + if func then + setfenv(func, env) + return func + end + return nil, err + end + end + setfenv(func, env) + local success, value_or_err = pcall(func) + if success then + return value_or_err end + return nil, value_or_err end diff --git a/builtin/common/tests/serialize_spec.lua b/builtin/common/tests/serialize_spec.lua index 69b2b567c..ea79680d7 100644 --- a/builtin/common/tests/serialize_spec.lua +++ b/builtin/common/tests/serialize_spec.lua @@ -6,38 +6,92 @@ _G.setfenv = require 'busted.compatibility'.setfenv dofile("builtin/common/serialize.lua") dofile("builtin/common/vector.lua") +-- Supports circular tables; does not support table keys +-- Correctly checks whether a mapping of references ("same") exists +-- Is significantly more efficient than assert.same +local function assert_same(a, b, same) + same = same or {} + if same[a] or same[b] then + assert(same[a] == b and same[b] == a) + return + end + if a == b then + return + end + if type(a) ~= "table" or type(b) ~= "table" then + assert(a == b) + return + end + same[a] = b + same[b] = a + local count = 0 + for k, v in pairs(a) do + count = count + 1 + assert(type(k) ~= "table") + assert_same(v, b[k], same) + end + for _ in pairs(b) do + count = count - 1 + end + assert(count == 0) +end + +local x, y = {}, {} +local t1, t2 = {x, x, y, y}, {x, y, x, y} +assert.same(t1, t2) -- will succeed because it only checks whether the depths match +assert(not pcall(assert_same, t1, t2)) -- will correctly fail because it checks whether the refs match + describe("serialize", function() + local function assert_preserves(value) + local preserved_value = core.deserialize(core.serialize(value)) + assert_same(value, preserved_value) + end it("works", function() - local test_in = {cat={sound="nyan", speed=400}, dog={sound="woof"}} - local test_out = core.deserialize(core.serialize(test_in)) - - assert.same(test_in, test_out) + assert_preserves({cat={sound="nyan", speed=400}, dog={sound="woof"}}) end) it("handles characters", function() - local test_in = {escape_chars="\n\r\t\v\\\"\'", non_european="θשׁ٩∂"} - local test_out = core.deserialize(core.serialize(test_in)) - assert.same(test_in, test_out) + assert_preserves({escape_chars="\n\r\t\v\\\"\'", non_european="θשׁ٩∂"}) + end) + + it("handles NaN & infinities", function() + local nan = core.deserialize(core.serialize(0/0)) + assert(nan ~= nan) + assert_preserves(math.huge) + assert_preserves(-math.huge) end) it("handles precise numbers", function() - local test_in = 0.2695949158945771 - local test_out = core.deserialize(core.serialize(test_in)) - assert.same(test_in, test_out) + assert_preserves(0.2695949158945771) end) it("handles big integers", function() - local test_in = 269594915894577 - local test_out = core.deserialize(core.serialize(test_in)) - assert.same(test_in, test_out) + assert_preserves(269594915894577) end) it("handles recursive structures", function() local test_in = { hello = "world" } test_in.foo = test_in + assert_preserves(test_in) + end) + + it("handles cross-referencing structures", function() + local test_in = { + foo = { + baz = { + {} + }, + }, + bar = { + baz = {}, + }, + } - local test_out = core.deserialize(core.serialize(test_in)) - assert.same(test_in, test_out) + test_in.foo.baz[1].foo = test_in.foo + test_in.foo.baz[1].bar = test_in.bar + test_in.bar.baz[1] = test_in.foo.baz[1] + + assert_preserves(test_in) end) it("strips functions in safe mode", function() @@ -47,6 +101,7 @@ describe("serialize", function() end, foo = "bar" } + setfenv(test_in.func, _G) local str = core.serialize(test_in) assert.not_nil(str:find("loadstring")) @@ -58,13 +113,75 @@ describe("serialize", function() it("vectors work", function() local v = vector.new(1, 2, 3) - assert.same({{x = 1, y = 2, z = 3}}, core.deserialize(core.serialize({v}))) - assert.same({x = 1, y = 2, z = 3}, core.deserialize(core.serialize(v))) + assert_preserves({v}) + assert_preserves(v) -- abuse v = vector.new(1, 2, 3) v.a = "bla" - assert.same({x = 1, y = 2, z = 3, a = "bla"}, - core.deserialize(core.serialize(v))) + assert_preserves(v) + end) + + it("handles keywords as keys", function() + assert_preserves({["and"] = "keyword", ["for"] = "keyword"}) + end) + + describe("fuzzing", function() + local atomics = {true, false, math.huge, -math.huge} -- no NaN or nil + local function atomic() + return atomics[math.random(1, #atomics)] + end + local function num() + local sign = math.random() < 0.5 and -1 or 1 + local val = math.random(0, 2^52) + local exp = math.random() < 0.5 and 1 or 2^(math.random(-120, 120)) + return sign * val * exp + end + local function charcodes(count) + if count == 0 then return end + return math.random(0, 0xFF), charcodes(count - 1) + end + local function str() + return string.char(charcodes(math.random(0, 100))) + end + local primitives = {atomic, num, str} + local function primitive() + return primitives[math.random(1, #primitives)]() + end + local function tab(max_actions) + local root = {} + local tables = {root} + local function random_table() + return tables[#tables == 1 and 1 or math.random(1, #tables)] -- luacheck: ignore + end + for _ = 1, math.random(1, max_actions) do + local tab = random_table() + local value + if math.random() < 0.5 then + if math.random() < 0.5 then + value = random_table() + else + value = {} + table.insert(tables, value) + end + else + value = primitive() + end + tab[math.random() < 0.5 and (#tab + 1) or primitive()] = value + end + return root + end + it("primitives work", function() + for _ = 1, 1e3 do + assert_preserves(primitive()) + end + end) + it("tables work", function() + for _ = 1, 100 do + local fuzzed_table = tab(1e3) + assert_same(fuzzed_table, table.copy(fuzzed_table)) + assert_preserves(fuzzed_table) + end + end) end) end) -- cgit v1.2.3