diff --git a/.editorconfig b/.editorconfig index 3ab16b35c6d..feb81b8ccc5 100644 --- a/.editorconfig +++ b/.editorconfig @@ -14,6 +14,10 @@ tab_width = 8 indent_style = space indent_size = 2 +[{*.java,*.java.erb}] +indent_style = space +indent_size = 4 + [{*[Mm]akefile*,*.mak,*.mk,depend}] indent_style = tab indent_size = 4 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2b691ece7e5..1a7cdb1681a 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -52,6 +52,7 @@ jobs: bundle exec rake compile bundle exec rake check_annotations bundle exec rake typecheck:steep + rm lib/prism/node.rb && CHECK_FIELD_KIND=true bundle exec rake build: strategy: diff --git a/config.yml b/config.yml index a89052d7699..4357663d48e 100644 --- a/config.yml +++ b/config.yml @@ -1556,6 +1556,9 @@ nodes: ^ - name: elements type: node[] + kind: + - AssocNode + - AssocSplatNode comment: | The elements of the hash. These can be either `AssocNode`s or `AssocSplatNode`s. @@ -1903,6 +1906,10 @@ nodes: type: location - name: parts type: node[] + kind: + - StringNode + - EmbeddedStatementsNode + - EmbeddedVariableNode - name: closing_loc type: location newline: parts @@ -1920,6 +1927,10 @@ nodes: type: location - name: parts type: node[] + kind: + - StringNode + - EmbeddedStatementsNode + - EmbeddedVariableNode - name: closing_loc type: location newline: parts @@ -1934,6 +1945,11 @@ nodes: type: location? - name: parts type: node[] + kind: + - StringNode + - EmbeddedStatementsNode + - EmbeddedVariableNode + - InterpolatedStringNode # `"a" "#{b}"` - name: closing_loc type: location? newline: parts @@ -1948,6 +1964,10 @@ nodes: type: location? - name: parts type: node[] + kind: + - StringNode + - EmbeddedStatementsNode + - EmbeddedVariableNode - name: closing_loc type: location? newline: parts @@ -1962,6 +1982,10 @@ nodes: type: location - name: parts type: node[] + kind: + - StringNode + - EmbeddedStatementsNode + - EmbeddedVariableNode - name: closing_loc type: location newline: parts @@ -1983,6 +2007,9 @@ nodes: kind: KeywordHashNodeFlags - name: elements type: node[] + kind: + - AssocNode + - AssocSplatNode comment: | Represents a hash literal without opening and closing braces. @@ -2190,6 +2217,7 @@ nodes: kind: CallNode - name: targets type: node[] + kind: LocalVariableTargetNode comment: | Represents writing local variables using a regular expression match with named capture groups. @@ -2221,10 +2249,35 @@ nodes: fields: - name: lefts type: node[] + kind: + - LocalVariableTargetNode + - InstanceVariableTargetNode + - ClassVariableTargetNode + - GlobalVariableTargetNode + - ConstantTargetNode + - ConstantPathTargetNode + - CallTargetNode + - IndexTargetNode + - MultiTargetNode + - RequiredParameterNode + - BackReferenceReadNode # On parsing error of `$',` + - NumberedReferenceReadNode # On parsing error of `$1,` - name: rest type: node? - name: rights type: node[] + kind: + - LocalVariableTargetNode + - InstanceVariableTargetNode + - ClassVariableTargetNode + - GlobalVariableTargetNode + - ConstantTargetNode + - ConstantPathTargetNode + - CallTargetNode + - IndexTargetNode + - MultiTargetNode + - RequiredParameterNode + - BackReferenceReadNode # On parsing error of `*,$'` - name: lparen_loc type: location? - name: rparen_loc @@ -2238,10 +2291,30 @@ nodes: fields: - name: lefts type: node[] + kind: + - LocalVariableTargetNode + - InstanceVariableTargetNode + - ClassVariableTargetNode + - GlobalVariableTargetNode + - ConstantTargetNode + - ConstantPathTargetNode + - CallTargetNode + - IndexTargetNode + - MultiTargetNode - name: rest type: node? - name: rights type: node[] + kind: + - LocalVariableTargetNode + - InstanceVariableTargetNode + - ClassVariableTargetNode + - GlobalVariableTargetNode + - ConstantTargetNode + - ConstantPathTargetNode + - CallTargetNode + - IndexTargetNode + - MultiTargetNode - name: lparen_loc type: location? - name: rparen_loc @@ -2820,6 +2893,9 @@ nodes: fields: - name: names type: node[] + kind: + - SymbolNode + - InterpolatedSymbolNode - name: keyword_loc type: location comment: | diff --git a/templates/java/org/prism/Loader.java.erb b/templates/java/org/prism/Loader.java.erb index b2327a7a140..0034a6db712 100644 --- a/templates/java/org/prism/Loader.java.erb +++ b/templates/java/org/prism/Loader.java.erb @@ -257,18 +257,6 @@ public class Loader { return constants; } - private Nodes.Node[] loadNodes() { - int length = loadVarUInt(); - if (length == 0) { - return Nodes.Node.EMPTY_ARRAY; - } - Nodes.Node[] nodes = new Nodes.Node[length]; - for (int i = 0; i < length; i++) { - nodes[i] = loadNode(); - } - return nodes; - } - private Nodes.Location loadLocation() { return new Nodes.Location(loadVarUInt(), loadVarUInt()); } @@ -367,6 +355,7 @@ public class Loader { int length = loadVarUInt(); switch (type) { + <%- array_types = [] -%> <%- nodes.each_with_index do |node, index| -%> case <%= index + 1 %>: <%- @@ -376,7 +365,10 @@ public class Loader { when Prism::NodeField then "#{field.java_cast}loadNode()" when Prism::OptionalNodeField then "#{field.java_cast}loadOptionalNode()" when Prism::StringField then "loadString()" - when Prism::NodeListField then "loadNodes()" + when Prism::NodeListField then + element_type = field.java_type.sub('[]', '') + array_types << element_type + "load#{element_type}s()" when Prism::ConstantField then "loadConstant()" when Prism::OptionalConstantField then "loadOptionalConstant()" when Prism::ConstantListField then "loadConstants()" @@ -398,6 +390,26 @@ public class Loader { throw new Error("Unknown node type: " + type); } } + <%- array_types.uniq.each do |type| -%> + + private static final Nodes.<%= type %>[] EMPTY_<%= type %>_ARRAY = {}; + + private Nodes.<%= type %>[] load<%= type %>s() { + int length = loadVarUInt(); + if (length == 0) { + return EMPTY_<%= type %>_ARRAY; + } + Nodes.<%= type %>[] nodes = new Nodes.<%= type %>[length]; + for (int i = 0; i < length; i++) { + <%- if type == 'Node' -%> + nodes[i] = loadNode(); + <%- else -%> + nodes[i] = (Nodes.<%= type %>) loadNode(); + <%- end -%> + } + return nodes; + } + <%- end -%> private void expect(byte value, String error) { byte b = buffer.get(); diff --git a/templates/java/org/prism/Nodes.java.erb b/templates/java/org/prism/Nodes.java.erb index 2b5e1afa725..69e6986e9f2 100644 --- a/templates/java/org/prism/Nodes.java.erb +++ b/templates/java/org/prism/Nodes.java.erb @@ -23,6 +23,12 @@ public abstract class Nodes { public @interface Nullable { } + @Target(ElementType.FIELD) + @Retention(RetentionPolicy.SOURCE) + public @interface UnionType { + Class[] value(); + } + public static final class Location { public static final Location[] EMPTY_ARRAY = {}; @@ -207,6 +213,9 @@ public abstract class Nodes { <%- if field.class.name.include?('Optional') -%> @Nullable <%- end -%> + <%- if field.respond_to?(:union_kind) && field.union_kind -%> + @UnionType({ <%= field.union_kind.map { |t| "#{t}.class" }.join(', ') %> }) + <%- end -%> public final <%= field.java_type %> <%= field.name %>; <%- end -%> diff --git a/templates/lib/prism/node.rb.erb b/templates/lib/prism/node.rb.erb index 32134f88208..4762963bf6f 100644 --- a/templates/lib/prism/node.rb.erb +++ b/templates/lib/prism/node.rb.erb @@ -110,6 +110,9 @@ module Prism @newline = false @location = location <%- node.fields.each do |field| -%> + <%- if Prism::CHECK_FIELD_KIND && field.respond_to?(:check_field_kind) -%> + raise <%= field.name %>.inspect unless <%= field.check_field_kind %> + <%- end -%> @<%= field.name %> = <%= field.name %> <%- end -%> end diff --git a/templates/template.rb b/templates/template.rb index d981b895926..fd55d5228ba 100755 --- a/templates/template.rb +++ b/templates/template.rb @@ -6,6 +6,7 @@ module Prism SERIALIZE_ONLY_SEMANTICS_FIELDS = ENV.fetch("PRISM_SERIALIZE_ONLY_SEMANTICS_FIELDS", false) + CHECK_FIELD_KIND = ENV.fetch("CHECK_FIELD_KIND", false) JAVA_BACKEND = ENV["PRISM_JAVA_BACKEND"] || "truffleruby" JAVA_STRING_TYPE = JAVA_BACKEND == "jruby" ? "org.jruby.RubySymbol" : "String" @@ -123,6 +124,14 @@ def rbs_class def rbi_class "Prism::#{ruby_type}" end + + def check_field_kind + if union_kind + "[#{union_kind.join(', ')}].include?(#{name}.class)" + else + "#{name}.is_a?(#{ruby_type})" + end + end end # This represents a field on a node that is itself a node and can be @@ -141,11 +150,19 @@ def rbs_class def rbi_class "T.nilable(Prism::#{ruby_type})" end + + def check_field_kind + if union_kind + "[#{union_kind.join(', ')}, NilClass].include?(#{name}.class)" + else + "#{name}.nil? || #{name}.is_a?(#{ruby_type})" + end + end end # This represents a field on a node that is a list of nodes. We pass them as # references and store them directly on the struct. - class NodeListField < Field + class NodeListField < NodeKindField def rbs_class if specific_kind "Array[#{specific_kind}]" @@ -157,20 +174,19 @@ def rbs_class end def rbi_class - "T::Array[Prism::Node]" + "T::Array[Prism::#{ruby_type}]" end def java_type - "Node[]" + "#{super}[]" end - # TODO: unduplicate with NodeKindField - def specific_kind - options[:kind] unless options[:kind].is_a?(Array) - end - - def union_kind - options[:kind] if options[:kind].is_a?(Array) + def check_field_kind + if union_kind + "#{name}.all? { |n| [#{union_kind.join(', ')}].include?(n.class) }" + else + "#{name}.all? { |n| n.is_a?(#{ruby_type}) }" + end end end