Skip to content

Commit

Permalink
Fix parsing for typed function reference types
Browse files Browse the repository at this point in the history
Previously an indexed ref type was parsed as a `Var`,
assuming that regular types would be represented as
index vars and named ref types would be named vars.

Unfortunately, it's also possible for a ref type to
be an indexed var that overlaps with a type (e.g.,
the type "any" which is 0x0 and the index 0).

This commit restructures the code so that a `Type`
is returned instead. In order to look up the index
of the parsed type in the type section, the `Module`
that is being generated will be tracked by the parser.
  • Loading branch information
takikawa committed Apr 4, 2022
1 parent 27c5d11 commit fcc7e47
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 107 deletions.
7 changes: 0 additions & 7 deletions src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,6 @@ struct FuncSignature {
TypeVector param_types;
TypeVector result_types;

// Some types can have names, for example (ref $foo) has type $foo.
// So to use this type we need to translate its name into
// a proper index from the module type section.
// This is the mapping from parameter/result index to its name.
std::unordered_map<uint32_t, std::string> param_type_names;
std::unordered_map<uint32_t, std::string> result_type_names;

Index GetNumParams() const { return param_types.size(); }
Index GetNumResults() const { return result_types.size(); }
Type GetParamType(Index index) const { return param_types[index]; }
Expand Down
134 changes: 44 additions & 90 deletions src/wast-parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,39 +268,6 @@ bool ResolveFuncTypeWithEmptySignature(const Module& module,
return false;
}

void ResolveTypeName(
const Module& module,
Type& type,
Index index,
const std::unordered_map<uint32_t, std::string>& bindings) {
if (type != Type::Reference || type.GetReferenceIndex() != kInvalidIndex) {
return;
}

const auto name_iterator = bindings.find(index);
assert(name_iterator != bindings.cend());
const auto type_index = module.type_bindings.FindIndex(name_iterator->second);
assert(type_index != kInvalidIndex);
type = Type(Type::Reference, type_index);
}

// Runs ResolveTypeName over every parameter type and every result type of
// `decl->sig`, using the signature's param_type_names / result_type_names maps
// as the position -> name bindings. `decl` must be non-null (asserted).
void ResolveTypeNames(const Module& module, FuncDeclaration* decl) {
  assert(decl);
  auto& sig = decl->sig;

  // Parameters first...
  const uint32_t num_params = sig.GetNumParams();
  for (uint32_t i = 0; i < num_params; ++i) {
    ResolveTypeName(module, sig.param_types[i], i, sig.param_type_names);
  }

  // ...then results.
  const uint32_t num_results = sig.GetNumResults();
  for (uint32_t i = 0; i < num_results; ++i) {
    ResolveTypeName(module, sig.result_types[i], i, sig.result_type_names);
  }
}

void ResolveImplicitlyDefinedFunctionType(const Location& loc,
Module* module,
const FuncDeclaration& decl) {
Expand Down Expand Up @@ -406,7 +373,6 @@ class ResolveFuncTypesExprVisitorDelegate : public ExprVisitor::DelegateNop {
: module_(module), errors_(errors) {}

void ResolveBlockDeclaration(const Location& loc, BlockDeclaration* decl) {
ResolveTypeNames(*module_, decl);
ResolveFuncTypeWithEmptySignature(*module_, decl);
if (!IsInlinableFuncSignature(decl->sig)) {
ResolveImplicitlyDefinedFunctionType(loc, module_, *decl);
Expand Down Expand Up @@ -485,7 +451,6 @@ Result ResolveFuncTypes(Module* module, Errors* errors) {
bool has_func_type_and_empty_signature = false;

if (decl) {
ResolveTypeNames(*module, decl);
has_func_type_and_empty_signature =
ResolveFuncTypeWithEmptySignature(*module, decl);
ResolveImplicitlyDefinedFunctionType(field.loc, module, *decl);
Expand Down Expand Up @@ -710,6 +675,13 @@ Result WastParser::ErrorIfLpar(const std::vector<std::string>& expected,
return Result::Ok;
}

// Translates the type variable `name` into an index by asking the module that
// is currently being parsed (module_) for its function type index.
//
// Preconditions, both checked only with assert: module_ is set, and the
// lookup does not yield kInvalidIndex.
//
// NOTE(review): the lookup runs while parsing, so presumably a type name that
// is defined later in the module text is not bound yet at this point --
// confirm whether forward references need to be supported.
Index WastParser::LookupTypeName(Var name) {
  assert(module_);
  const Index type_index = module_->GetFuncTypeIndex(name);
  assert(type_index != kInvalidIndex);
  return type_index;
}

bool WastParser::ParseBindVarOpt(std::string* name) {
WABT_TRACE(ParseBindVarOpt);
if (!PeekMatch(TokenType::Var)) {
Expand Down Expand Up @@ -853,7 +825,7 @@ bool WastParser::ParseElemExprVarListOpt(ExprListVector* out_list) {
return !out_list->empty();
}

Result WastParser::ParseValueType(Var* out_type) {
Result WastParser::ParseValueType(Type* out_type) {
WABT_TRACE(ParseValueType);

const bool is_ref_type = PeekMatchRefType();
Expand All @@ -866,8 +838,15 @@ Result WastParser::ParseValueType(Var* out_type) {
if (is_ref_type) {
EXPECT(Lpar);
EXPECT(Ref);
CHECK_RESULT(ParseVar(out_type));
Var name;
CHECK_RESULT(ParseVar(&name));
EXPECT(Rpar);
if (name.is_index()) {
*out_type = Type(Type::Reference, name.index());
} else {
Index type_index = LookupTypeName(name);
*out_type = Type(Type::Reference, type_index);
}
return Result::Ok;
}

Expand All @@ -892,30 +871,20 @@ Result WastParser::ParseValueType(Var* out_type) {
return Result::Error;
}

*out_type = Var(type);
*out_type = type;
return Result::Ok;
}

Result WastParser::ParseValueTypeList(
TypeVector* out_type_list,
std::unordered_map<uint32_t, std::string>* type_names) {
Result WastParser::ParseValueTypeList(TypeVector* out_type_list) {
WABT_TRACE(ParseValueTypeList);
while (true) {
if (!PeekMatchRefType() && !PeekMatch(TokenType::ValueType)) {
break;
}

Var type;
Type type;
CHECK_RESULT(ParseValueType(&type));

if (type.is_index()) {
out_type_list->push_back(Type(type.index()));
} else {
assert(type.is_name());
assert(options_->features.function_references_enabled());
type_names->emplace(out_type_list->size(), type.name());
out_type_list->push_back(Type(Type::Reference, kInvalidIndex));
}
out_type_list->push_back(type);
}

return Result::Ok;
Expand Down Expand Up @@ -1108,6 +1077,7 @@ Result WastParser::ParseNat(uint64_t* out_nat, bool is_64) {
Result WastParser::ParseModule(std::unique_ptr<Module>* out_module) {
WABT_TRACE(ParseModule);
auto module = MakeUnique<Module>();
auto scope = ModuleScope(this, module.get());

if (PeekMatchLpar(TokenType::Module)) {
// Starts with "(module". Allow text and binary modules, but no quoted
Expand Down Expand Up @@ -1143,6 +1113,7 @@ Result WastParser::ParseScript(std::unique_ptr<Script>* out_script) {
if (IsModuleField(PeekPair())) {
// Parse an inline module (i.e. one with no surrounding (module)).
auto command = MakeUnique<ModuleCommand>();
ModuleScope(this, &command->module);
command->module.loc = GetLocation();
CHECK_RESULT(ParseModuleFieldList(&command->module));
script->commands.emplace_back(std::move(command));
Expand Down Expand Up @@ -1365,9 +1336,8 @@ Result WastParser::ParseFuncModuleField(Module* module) {
CHECK_RESULT(ParseTypeUseOpt(&func.decl));
CHECK_RESULT(ParseFuncSignature(&func.decl.sig, &func.bindings));
TypeVector local_types;
CHECK_RESULT(ParseBoundValueTypeList(
TokenType::Local, &local_types, &func.bindings,
&func.decl.sig.param_type_names, func.GetNumParams()));
CHECK_RESULT(ParseBoundValueTypeList(TokenType::Local, &local_types,
&func.bindings, func.GetNumParams()));
func.local_types.Set(local_types);
CHECK_RESULT(ParseTerminatingInstrList(&func.exprs));
module->AppendField(std::move(field));
Expand Down Expand Up @@ -1427,15 +1397,11 @@ Result WastParser::ParseField(Field* field) {
// TODO: Share with ParseGlobalType?
if (MatchLpar(TokenType::Mut)) {
field->mutable_ = true;
Var type;
CHECK_RESULT(ParseValueType(&type));
field->type = Type(type.index());
CHECK_RESULT(ParseValueType(&field->type));
EXPECT(Rpar);
} else {
field->mutable_ = false;
Var type;
CHECK_RESULT(ParseValueType(&type));
field->type = Type(type.index());
CHECK_RESULT(ParseValueType(&field->type));
}
return Result::Ok;
};
Expand Down Expand Up @@ -1770,68 +1736,55 @@ Result WastParser::ParseFuncSignature(FuncSignature* sig,
BindingHash* param_bindings) {
WABT_TRACE(ParseFuncSignature);
CHECK_RESULT(ParseBoundValueTypeList(TokenType::Param, &sig->param_types,
param_bindings, &sig->param_type_names));
CHECK_RESULT(ParseResultList(&sig->result_types, &sig->result_type_names));
param_bindings));
CHECK_RESULT(ParseResultList(&sig->result_types));
return Result::Ok;
}

Result WastParser::ParseUnboundFuncSignature(FuncSignature* sig) {
WABT_TRACE(ParseUnboundFuncSignature);
CHECK_RESULT(ParseUnboundValueTypeList(TokenType::Param, &sig->param_types,
&sig->param_type_names));
CHECK_RESULT(ParseResultList(&sig->result_types, &sig->result_type_names));
CHECK_RESULT(ParseUnboundValueTypeList(TokenType::Param, &sig->param_types));
CHECK_RESULT(ParseResultList(&sig->result_types));
return Result::Ok;
}

Result WastParser::ParseBoundValueTypeList(
TokenType token,
TypeVector* types,
BindingHash* bindings,
std::unordered_map<uint32_t, std::string>* type_names,
Index binding_index_offset) {
WABT_TRACE(ParseBoundValueTypeList);
while (MatchLpar(token)) {
if (PeekMatch(TokenType::Var)) {
std::string name;
Var type;
Type type;
Location loc = GetLocation();
ParseBindVarOpt(&name);
CHECK_RESULT(ParseValueType(&type));
bindings->emplace(name,
Binding(loc, binding_index_offset + types->size()));
if (type.is_index()) {
types->push_back(Type(type.index()));
} else {
assert(type.is_name());
assert(options_->features.function_references_enabled());
type_names->emplace(binding_index_offset + types->size(), type.name());
types->push_back(Type(Type::Reference, kInvalidIndex));
}
types->push_back(type);
} else {
CHECK_RESULT(ParseValueTypeList(types, type_names));
CHECK_RESULT(ParseValueTypeList(types));
}
EXPECT(Rpar);
}
return Result::Ok;
}

Result WastParser::ParseUnboundValueTypeList(
TokenType token,
TypeVector* types,
std::unordered_map<uint32_t, std::string>* type_names) {
Result WastParser::ParseUnboundValueTypeList(TokenType token,
TypeVector* types) {
WABT_TRACE(ParseUnboundValueTypeList);
while (MatchLpar(token)) {
CHECK_RESULT(ParseValueTypeList(types, type_names));
CHECK_RESULT(ParseValueTypeList(types));
EXPECT(Rpar);
}
return Result::Ok;
}

Result WastParser::ParseResultList(
TypeVector* result_types,
std::unordered_map<uint32_t, std::string>* type_names) {
Result WastParser::ParseResultList(TypeVector* result_types) {
WABT_TRACE(ParseResultList);
return ParseUnboundValueTypeList(TokenType::Result, result_types, type_names);
return ParseUnboundValueTypeList(TokenType::Result, result_types);
}

Result WastParser::ParseInstrList(ExprList* exprs) {
Expand Down Expand Up @@ -2082,7 +2035,7 @@ Result WastParser::ParsePlainInstr(std::unique_ptr<Expr>* out_expr) {
TypeVector result;
if (options_->features.reference_types_enabled() &&
MatchLpar(TokenType::Result)) {
CHECK_RESULT(ParseValueTypeList(&result, nullptr));
CHECK_RESULT(ParseValueTypeList(&result));
EXPECT(Rpar);
}
out_expr->reset(new SelectExpr(result, loc));
Expand Down Expand Up @@ -3116,15 +3069,15 @@ Result WastParser::ParseGlobalType(Global* global) {
WABT_TRACE(ParseGlobalType);
if (MatchLpar(TokenType::Mut)) {
global->mutable_ = true;
Var type;
Type type;
CHECK_RESULT(ParseValueType(&type));
global->type = Type(type.index());
global->type = type;
CHECK_RESULT(ErrorIfLpar({"i32", "i64", "f32", "f64"}));
EXPECT(Rpar);
} else {
Var type;
Type type;
CHECK_RESULT(ParseValueType(&type));
global->type = Type(type.index());
global->type = type;
}

return Result::Ok;
Expand Down Expand Up @@ -3440,6 +3393,7 @@ Result WastParser::ParseScriptModule(

default: {
auto tsm = MakeUnique<TextScriptModule>();
auto scope = ModuleScope(this, &tsm->module);
tsm->module.name = name;
tsm->module.loc = loc;
if (IsModuleField(PeekPair())) {
Expand Down
35 changes: 25 additions & 10 deletions src/wast-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ class WastParser {
Expectation,
};

// RAII class for maintaining a pointer to the current module from the parser.
class ModuleScope {
public:
ModuleScope(WastParser* parser, Module* module)
: parser(parser), old_module(parser->module_) {
parser->module_ = module;
}

~ModuleScope() { parser->module_ = old_module; }

WastParser* parser;
Module* old_module;
};

void ErrorUnlessOpcodeEnabled(const Token&);

// Print an error message listing the expected tokens, as well as an example
Expand Down Expand Up @@ -122,6 +136,9 @@ class WastParser {
// synchronized.
Result Synchronize(SynchronizeFunc);

// Look up a type name to index mapping.
Index LookupTypeName(Var name);

bool ParseBindVarOpt(std::string* name);
Result ParseVar(Var* out_var);
bool ParseVarOpt(Var* out_var, Var default_var = Var());
Expand All @@ -133,10 +150,8 @@ class WastParser {
bool ParseElemExprOpt(ExprList* out_elem_expr);
bool ParseElemExprListOpt(ExprListVector* out_list);
bool ParseElemExprVarListOpt(ExprListVector* out_list);
Result ParseValueType(Var* out_type);
Result ParseValueTypeList(
TypeVector* out_type_list,
std::unordered_map<uint32_t, std::string>* type_names);
Result ParseValueType(Type* out_type);
Result ParseValueTypeList(TypeVector* out_type_list);
Result ParseRefKind(Type* out_type);
Result ParseRefType(Type* out_type);
bool ParseRefTypeOpt(Type* out_type);
Expand Down Expand Up @@ -171,13 +186,9 @@ class WastParser {
Result ParseBoundValueTypeList(TokenType,
TypeVector*,
BindingHash*,
std::unordered_map<uint32_t, std::string>*,
Index binding_index_offset = 0);
Result ParseUnboundValueTypeList(TokenType,
TypeVector*,
std::unordered_map<uint32_t, std::string>*);
Result ParseResultList(TypeVector*,
std::unordered_map<uint32_t, std::string>*);
Result ParseUnboundValueTypeList(TokenType, TypeVector*);
Result ParseResultList(TypeVector*);
Result ParseInstrList(ExprList*);
Result ParseTerminatingInstrList(ExprList*);
Result ParseInstr(ExprList*);
Expand Down Expand Up @@ -257,6 +268,10 @@ class WastParser {
Errors* errors_;
WastParseOptions* options_;

// Used to look up global state such as type names. Use the RAII class
// ModuleScope to ensure the pointer points to the right module.
Module* module_ = nullptr;

CircularArray<Token, 2> tokens_;
};

Expand Down
Loading

0 comments on commit fcc7e47

Please sign in to comment.