Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give a kind to all node[] fields and better types in Java-generated code #2497

Merged
merged 8 commits into from
Feb 28, 2024
4 changes: 4 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ tab_width = 8
indent_style = space
indent_size = 2

[{*.java,*.java.erb}]
indent_style = space
indent_size = 4

[{*[Mm]akefile*,*.mak,*.mk,depend}]
indent_style = tab
indent_size = 4
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
bundle exec rake compile
bundle exec rake check_annotations
bundle exec rake typecheck:steep
rm lib/prism/node.rb && CHECK_FIELD_KIND=true bundle exec rake

build:
strategy:
Expand Down
76 changes: 76 additions & 0 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,9 @@ nodes:
^
- name: elements
type: node[]
kind:
- AssocNode
- AssocSplatNode
comment: |
The elements of the hash. These can be either `AssocNode`s or `AssocSplatNode`s.

Expand Down Expand Up @@ -1903,6 +1906,10 @@ nodes:
type: location
- name: parts
type: node[]
kind:
- StringNode
- EmbeddedStatementsNode
- EmbeddedVariableNode
- name: closing_loc
type: location
newline: parts
Expand All @@ -1920,6 +1927,10 @@ nodes:
type: location
- name: parts
type: node[]
kind:
- StringNode
- EmbeddedStatementsNode
- EmbeddedVariableNode
- name: closing_loc
type: location
newline: parts
Expand All @@ -1934,6 +1945,11 @@ nodes:
type: location?
- name: parts
type: node[]
kind:
- StringNode
- EmbeddedStatementsNode
- EmbeddedVariableNode
- InterpolatedStringNode # `"a" "#{b}"`
- name: closing_loc
type: location?
newline: parts
Expand All @@ -1948,6 +1964,10 @@ nodes:
type: location?
- name: parts
type: node[]
kind:
- StringNode
- EmbeddedStatementsNode
- EmbeddedVariableNode
- name: closing_loc
type: location?
newline: parts
Expand All @@ -1962,6 +1982,10 @@ nodes:
type: location
- name: parts
type: node[]
kind:
- StringNode
- EmbeddedStatementsNode
- EmbeddedVariableNode
- name: closing_loc
type: location
newline: parts
Expand All @@ -1983,6 +2007,9 @@ nodes:
kind: KeywordHashNodeFlags
- name: elements
type: node[]
kind:
- AssocNode
- AssocSplatNode
comment: |
Represents a hash literal without opening and closing braces.

Expand Down Expand Up @@ -2190,6 +2217,7 @@ nodes:
kind: CallNode
- name: targets
type: node[]
kind: LocalVariableTargetNode
comment: |
Represents writing local variables using a regular expression match with named capture groups.

Expand Down Expand Up @@ -2221,10 +2249,35 @@ nodes:
fields:
- name: lefts
type: node[]
kind:
- LocalVariableTargetNode
- InstanceVariableTargetNode
- ClassVariableTargetNode
- GlobalVariableTargetNode
- ConstantTargetNode
- ConstantPathTargetNode
- CallTargetNode
- IndexTargetNode
- MultiTargetNode
- RequiredParameterNode
- BackReferenceReadNode # On parsing error of `$',`
- NumberedReferenceReadNode # On parsing error of `$1,`
- name: rest
type: node?
- name: rights
type: node[]
kind:
- LocalVariableTargetNode
- InstanceVariableTargetNode
- ClassVariableTargetNode
- GlobalVariableTargetNode
- ConstantTargetNode
- ConstantPathTargetNode
- CallTargetNode
- IndexTargetNode
- MultiTargetNode
- RequiredParameterNode
- BackReferenceReadNode # On parsing error of `*,$'`
- name: lparen_loc
type: location?
- name: rparen_loc
Expand All @@ -2238,10 +2291,30 @@ nodes:
fields:
- name: lefts
type: node[]
kind:
- LocalVariableTargetNode
- InstanceVariableTargetNode
- ClassVariableTargetNode
- GlobalVariableTargetNode
- ConstantTargetNode
- ConstantPathTargetNode
- CallTargetNode
- IndexTargetNode
- MultiTargetNode
- name: rest
type: node?
- name: rights
type: node[]
kind:
- LocalVariableTargetNode
- InstanceVariableTargetNode
- ClassVariableTargetNode
- GlobalVariableTargetNode
- ConstantTargetNode
- ConstantPathTargetNode
- CallTargetNode
- IndexTargetNode
- MultiTargetNode
- name: lparen_loc
type: location?
- name: rparen_loc
Expand Down Expand Up @@ -2820,6 +2893,9 @@ nodes:
fields:
- name: names
type: node[]
kind:
- SymbolNode
- InterpolatedSymbolNode
- name: keyword_loc
type: location
comment: |
Expand Down
38 changes: 25 additions & 13 deletions templates/java/org/prism/Loader.java.erb
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,6 @@ public class Loader {
return constants;
}

private Nodes.Node[] loadNodes() {
int length = loadVarUInt();
if (length == 0) {
return Nodes.Node.EMPTY_ARRAY;
}
Nodes.Node[] nodes = new Nodes.Node[length];
for (int i = 0; i < length; i++) {
nodes[i] = loadNode();
}
return nodes;
}

private Nodes.Location loadLocation() {
return new Nodes.Location(loadVarUInt(), loadVarUInt());
}
Expand Down Expand Up @@ -367,6 +355,7 @@ public class Loader {
int length = loadVarUInt();

switch (type) {
<%- array_types = [] -%>
<%- nodes.each_with_index do |node, index| -%>
case <%= index + 1 %>:
<%-
Expand All @@ -376,7 +365,10 @@ public class Loader {
when Prism::NodeField then "#{field.java_cast}loadNode()"
when Prism::OptionalNodeField then "#{field.java_cast}loadOptionalNode()"
when Prism::StringField then "loadString()"
when Prism::NodeListField then "loadNodes()"
when Prism::NodeListField then
element_type = field.java_type.sub('[]', '')
array_types << element_type
"load#{element_type}s()"
when Prism::ConstantField then "loadConstant()"
when Prism::OptionalConstantField then "loadOptionalConstant()"
when Prism::ConstantListField then "loadConstants()"
Expand All @@ -398,6 +390,26 @@ public class Loader {
throw new Error("Unknown node type: " + type);
}
}
<%- array_types.uniq.each do |type| -%>

private static final Nodes.<%= type %>[] EMPTY_<%= type %>_ARRAY = {};

private Nodes.<%= type %>[] load<%= type %>s() {
int length = loadVarUInt();
if (length == 0) {
return EMPTY_<%= type %>_ARRAY;
}
Nodes.<%= type %>[] nodes = new Nodes.<%= type %>[length];
for (int i = 0; i < length; i++) {
<%- if type == 'Node' -%>
nodes[i] = loadNode();
<%- else -%>
nodes[i] = (Nodes.<%= type %>) loadNode();
<%- end -%>
}
return nodes;
}
<%- end -%>

private void expect(byte value, String error) {
byte b = buffer.get();
Expand Down
9 changes: 9 additions & 0 deletions templates/java/org/prism/Nodes.java.erb
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ public abstract class Nodes {
public @interface Nullable {
}

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.SOURCE)
public @interface UnionType {
Class<? extends Node>[] value();
}

public static final class Location {

public static final Location[] EMPTY_ARRAY = {};
Expand Down Expand Up @@ -207,6 +213,9 @@ public abstract class Nodes {
<%- if field.class.name.include?('Optional') -%>
@Nullable
<%- end -%>
<%- if field.respond_to?(:union_kind) && field.union_kind -%>
@UnionType({ <%= field.union_kind.map { |t| "#{t}.class" }.join(', ') %> })
<%- end -%>
public final <%= field.java_type %> <%= field.name %>;
<%- end -%>

Expand Down
3 changes: 3 additions & 0 deletions templates/lib/prism/node.rb.erb
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ module Prism
@newline = false
@location = location
<%- node.fields.each do |field| -%>
<%- if Prism::CHECK_FIELD_KIND && field.respond_to?(:check_field_kind) -%>
raise <%= field.name %>.inspect unless <%= field.check_field_kind %>
<%- end -%>
@<%= field.name %> = <%= field.name %>
<%- end -%>
end
Expand Down
36 changes: 26 additions & 10 deletions templates/template.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

module Prism
SERIALIZE_ONLY_SEMANTICS_FIELDS = ENV.fetch("PRISM_SERIALIZE_ONLY_SEMANTICS_FIELDS", false)
CHECK_FIELD_KIND = ENV.fetch("CHECK_FIELD_KIND", false)

JAVA_BACKEND = ENV["PRISM_JAVA_BACKEND"] || "truffleruby"
JAVA_STRING_TYPE = JAVA_BACKEND == "jruby" ? "org.jruby.RubySymbol" : "String"
Expand Down Expand Up @@ -123,6 +124,14 @@ def rbs_class
def rbi_class
"Prism::#{ruby_type}"
end

def check_field_kind
if union_kind
"[#{union_kind.join(', ')}].include?(#{name}.class)"
else
"#{name}.is_a?(#{ruby_type})"
end
end
end

# This represents a field on a node that is itself a node and can be
Expand All @@ -141,11 +150,19 @@ def rbs_class
def rbi_class
"T.nilable(Prism::#{ruby_type})"
end

def check_field_kind
if union_kind
"[#{union_kind.join(', ')}, NilClass].include?(#{name}.class)"
else
"#{name}.nil? || #{name}.is_a?(#{ruby_type})"
end
end
end

# This represents a field on a node that is a list of nodes. We pass them as
# references and store them directly on the struct.
class NodeListField < Field
class NodeListField < NodeKindField
def rbs_class
if specific_kind
"Array[#{specific_kind}]"
Expand All @@ -157,20 +174,19 @@ def rbs_class
end

def rbi_class
"T::Array[Prism::Node]"
"T::Array[Prism::#{ruby_type}]"
end

def java_type
"Node[]"
"#{super}[]"
end

# TODO: unduplicate with NodeKindField
def specific_kind
options[:kind] unless options[:kind].is_a?(Array)
end

def union_kind
options[:kind] if options[:kind].is_a?(Array)
def check_field_kind
if union_kind
"#{name}.all? { |n| [#{union_kind.join(', ')}].include?(n.class) }"
else
"#{name}.all? { |n| n.is_a?(#{ruby_type}) }"
end
end
end

Expand Down
Loading