diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/qapi-types.py | 6 | ||||
-rw-r--r-- | scripts/qapi-visit.py | 4 | ||||
-rw-r--r-- | scripts/qapi.py | 88 |
3 files changed, 60 insertions, 38 deletions
diff --git a/scripts/qapi-types.py b/scripts/qapi-types.py index 2390887f28..c9e0201d10 100644 --- a/scripts/qapi-types.py +++ b/scripts/qapi-types.py @@ -170,7 +170,7 @@ typedef enum %(name)s return lookup_decl + enum_decl -def generate_anon_union_qtypes(expr): +def generate_alternate_qtypes(expr): name = expr['union'] members = expr['data'] @@ -181,7 +181,7 @@ const int %(name)s_qtypes[QTYPE_MAX] = { name=name) for key in members: - qtype = find_anonymous_member_qtype(members[key]) + qtype = find_alternate_member_qtype(members[key]) assert qtype, "Invalid anonymous union member" ret += mcgen(''' @@ -408,7 +408,7 @@ for expr in exprs: fdef.write(generate_enum_lookup('%sKind' % expr['union'], expr['data'].keys())) if expr.get('discriminator') == {}: - fdef.write(generate_anon_union_qtypes(expr)) + fdef.write(generate_alternate_qtypes(expr)) else: continue fdecl.write(ret) diff --git a/scripts/qapi-visit.py b/scripts/qapi-visit.py index dbf0101cba..6bd2b6bfab 100644 --- a/scripts/qapi-visit.py +++ b/scripts/qapi-visit.py @@ -237,7 +237,7 @@ void visit_type_%(name)s(Visitor *m, %(name)s *obj, const char *name, Error **er ''', name=name) -def generate_visit_anon_union(name, members): +def generate_visit_alternate(name, members): ret = mcgen(''' void visit_type_%(name)s(Visitor *m, %(name)s **obj, const char *name, Error **errp) @@ -302,7 +302,7 @@ def generate_visit_union(expr): if discriminator == {}: assert not base - return generate_visit_anon_union(name, members) + return generate_visit_alternate(name, members) enum_define = discriminator_find_enum_define(expr) if enum_define: diff --git a/scripts/qapi.py b/scripts/qapi.py index 0c3459bfe2..0b88325abd 100644 --- a/scripts/qapi.py +++ b/scripts/qapi.py @@ -224,21 +224,16 @@ def find_base_fields(base): return None return base_struct_define['data'] -# Return the qtype of an anonymous union branch, or None on error. -def find_anonymous_member_qtype(qapi_type): +# Return the qtype of an alternate branch, or None on error. +def find_alternate_member_qtype(qapi_type): if builtin_types.has_key(qapi_type): return builtin_types[qapi_type] elif find_struct(qapi_type): return "QTYPE_QDICT" elif find_enum(qapi_type): return "QTYPE_QSTRING" - else: - union = find_union(qapi_type) - if union: - discriminator = union.get('discriminator') - if discriminator == {}: - return None - return "QTYPE_QDICT" + elif find_union(qapi_type): + return "QTYPE_QDICT" return None # Return the discriminator enum define if discriminator is specified as an @@ -276,7 +271,6 @@ def check_union(expr, expr_info): discriminator = expr.get('discriminator') members = expr['data'] values = { 'MAX': '(automatic)' } - types_seen = {} # If the object has a member 'base', its value must name a complex type, # and there must be a discriminator. @@ -286,13 +280,15 @@ def check_union(expr, expr_info): "Union '%s' requires a discriminator to go " "along with base" %name) - # If the union object has no member 'discriminator', it's a - # simple union. If 'discriminator' is {}, it is an anonymous union. - if discriminator is None or discriminator == {}: + # Two types of unions, determined by discriminator. + assert discriminator != {} + + # With no discriminator it is a simple union. + if discriminator is None: enum_define = None if base is not None: raise QAPIExprError(expr_info, - "Union '%s' must not have a base" + "Simple union '%s' must not have a base" % name) # Else, it's a flat union. @@ -347,24 +343,46 @@ def check_union(expr, expr_info): % (name, key, values[c_key])) values[c_key] = key - # Ensure anonymous unions have no type conflicts. - if discriminator == {}: - if isinstance(value, list): - raise QAPIExprError(expr_info, - "Anonymous union '%s' member '%s' must " - "not be array type" % (name, key)) - qtype = find_anonymous_member_qtype(value) - if not qtype: - raise QAPIExprError(expr_info, - "Anonymous union '%s' member '%s' has " - "invalid type '%s'" % (name, key, value)) - if qtype in types_seen: - raise QAPIExprError(expr_info, - "Anonymous union '%s' member '%s' can't " - "be distinguished from member '%s'" - % (name, key, types_seen[qtype])) - types_seen[qtype] = key +def check_alternate(expr, expr_info): + name = expr['union'] + base = expr.get('base') + discriminator = expr.get('discriminator') + members = expr['data'] + values = { 'MAX': '(automatic)' } + types_seen = {} + + assert discriminator == {} + if base is not None: + raise QAPIExprError(expr_info, + "Anonymous union '%s' must not have a base" + % name) + + # Check every branch + for (key, value) in members.items(): + # Check for conflicts in the generated enum + c_key = _generate_enum_string(key) + if c_key in values: + raise QAPIExprError(expr_info, + "Anonymous union '%s' member '%s' clashes " + "with '%s'" % (name, key, values[c_key])) + values[c_key] = key + # Ensure alternates have no type conflicts. + if isinstance(value, list): + raise QAPIExprError(expr_info, + "Anonymous union '%s' member '%s' must " + "not be array type" % (name, key)) + qtype = find_alternate_member_qtype(value) + if not qtype: + raise QAPIExprError(expr_info, + "Anonymous union '%s' member '%s' has " + "invalid type '%s'" % (name, key, value)) + if qtype in types_seen: + raise QAPIExprError(expr_info, + "Anonymous union '%s' member '%s' can't " + "be distinguished from member '%s'" + % (name, key, types_seen[qtype])) + types_seen[qtype] = key def check_enum(expr, expr_info): name = expr['enum'] @@ -394,7 +412,10 @@ def check_exprs(schema): if expr.has_key('enum'): check_enum(expr, info) elif expr.has_key('union'): - check_union(expr, info) + if expr.get('discriminator') == {}: + check_alternate(expr, info) + else: + check_union(expr, info) elif expr.has_key('event'): check_event(expr, info) @@ -536,7 +557,8 @@ def find_struct(name): def add_union(definition): global union_types - union_types.append(definition) + if definition.get('discriminator') != {}: + union_types.append(definition) def find_union(name): global union_types |