Skip to content

Commit 1b6513e

Browse files
authored
Inject Product super to generated case classes (#296)
1 parent c562aac commit 1b6513e

File tree

3 files changed

+33
-34
lines changed

3 files changed

+33
-34
lines changed

src/main/scala/com/spotify/scio/AnnotationTypeInjector.scala

+10-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package com.spotify.scio
1919

2020
import java.io.File
2121
import java.nio.charset.Charset
22-
import java.nio.file.{Paths, Files as JFiles}
22+
import java.nio.file.{Files as JFiles, Paths}
2323
import com.google.common.base.Charsets
2424
import com.google.common.hash.Hashing
2525
import com.google.common.io.Files
@@ -34,6 +34,15 @@ import scala.collection.mutable
3434

3535
object AnnotationTypeInjector {
3636
private val Log = Logger.getInstance(classOf[AnnotationTypeInjector])
37+
38+
// case classes implement Product trait
39+
val CaseClassSuper: String = "_root_.scala.Product"
40+
val CaseClassFunctions: Seq[String] = Seq(
41+
"def productArity: _root_.scala.Int = ???",
42+
"def productElement(n: _root_.scala.Int): _root_.scala.Any = ???",
43+
"def canEqual(x: _root_.scala.Any): _root_.scala.Boolean = ???"
44+
)
45+
3746
private val CaseClassArgs = """case\s+class\s+[^(]+\((.*)\).*""".r
3847
private val TypeArg = """[a-zA-Z0-9_$]+\s*:\s*[a-zA-Z0-9._$]+([\[(](.*?)[)\]]+)?""".r
3948
private val AlertEveryMissedXInvocations = 5

src/main/scala/com/spotify/scio/AvroTypeInjector.scala

+11-15
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ object AvroTypeInjector {
2626
s"$AvroTNamespace.fromPath",
2727
s"$AvroTNamespace.toSchema"
2828
)
29-
private val CaseClassSuper =
29+
30+
private val HasAvroAnnotationSuper =
3031
"_root_.com.spotify.scio.avro.types.AvroType.HasAvroAnnotation"
3132

3233
private def avroAnnotation(sc: ScClass): Option[String] =
@@ -44,25 +45,20 @@ final class AvroTypeInjector extends AnnotationTypeInjector {
4445
override def injectFunctions(source: ScTypeDefinition): Seq[String] =
4546
source match {
4647
case c: ScClass if avroAnnotation(c).isDefined =>
47-
val result = for {
48-
cc <- Option(c.containingClass)
49-
qn <- Option(cc.getQualifiedName)
48+
val fields = for {
49+
cc <- Option(c.containingClass).toSeq
50+
qn <- Option(cc.getQualifiedName).toSeq
5051
parent = qn.init
51-
defs <- {
52-
generatedCaseClasses(parent, c)
53-
.find(_.contains(CaseClassSuper))
54-
.map(getApplyPropsSignature)
55-
.map(v => s"def $v = ???")
56-
}
57-
} yield defs
58-
59-
result.toSeq
52+
cls <- generatedCaseClasses(parent, c).find(_.contains(HasAvroAnnotationSuper)).toSeq
53+
v <- getApplyPropsSignature(cls)
54+
} yield s"def $v = ???"
55+
CaseClassFunctions ++ fields
6056
case _ => Seq.empty
6157
}
6258

6359
override def injectSupers(source: ScTypeDefinition): Seq[String] =
6460
source match {
65-
case c: ScClass if avroAnnotation(c).isDefined => Seq(CaseClassSuper)
61+
case c: ScClass if avroAnnotation(c).isDefined => Seq(CaseClassSuper, HasAvroAnnotationSuper)
6662
case _ => Seq.empty
6763
}
6864

@@ -77,7 +73,7 @@ final class AvroTypeInjector extends AnnotationTypeInjector {
7773
case c: ScClass if avroAnnotation(c).isDefined =>
7874
val (annotated, other) =
7975
generatedCaseClasses(source.getQualifiedName.init, c).partition(
80-
_.contains(CaseClassSuper)
76+
_.contains(HasAvroAnnotationSuper)
8177
)
8278
(c, (annotated.headOption, other))
8379
}

src/main/scala/com/spotify/scio/BigQueryTypeInjector.scala

+12-18
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ object BigQueryTypeInjector {
3939
s"$BQTNamespace.toTable"
4040
)
4141

42-
private val CaseClassSuper =
42+
private val HasAnnotationSuper =
4343
"_root_.com.spotify.scio.bigquery.types.BigQueryType.HasAnnotation"
4444

4545
private def bqAnnotation(sc: ScClass): Option[String] =
@@ -76,14 +76,12 @@ object BigQueryTypeInjector {
7676
annotation match {
7777
case a if a.contains(FromQuery) =>
7878
val simple = """
79-
|def query: _root_.java.lang.String = ???
8079
|def queryRaw: _root_.java.lang.String = ???
8180
|""".stripMargin
8281

8382
bqQuerySignature(c)
8483
.map { params =>
8584
simple + s"""
86-
|def query($params): _root_.java.lang.String = ???
8785
|def queryAsSource($params): _root_.com.spotify.scio.bigquery.Query = ???
8886
|""".stripMargin
8987
}
@@ -110,25 +108,21 @@ final class BigQueryTypeInjector extends AnnotationTypeInjector {
110108
override def injectFunctions(source: ScTypeDefinition): Seq[String] =
111109
source match {
112110
case c: ScClass if bqAnnotation(c).isDefined =>
113-
val result = for {
114-
cc <- Option(c.containingClass)
115-
qn <- Option(cc.getQualifiedName)
111+
val fields = for {
112+
cc <- Option(c.containingClass).toSeq
113+
qn <- Option(cc.getQualifiedName).toSeq
116114
parent = qn.init
117-
defs <- {
118-
generatedCaseClasses(parent, c)
119-
.find(_.contains(CaseClassSuper))
120-
.map(getApplyPropsSignature)
121-
.map(v => s"def $v = ???")
122-
}
123-
} yield defs
124-
125-
result.toSeq
126-
case _ => Seq.empty
115+
cls <- generatedCaseClasses(parent, c).find(_.contains(HasAnnotationSuper)).toSeq
116+
v <- getApplyPropsSignature(cls)
117+
} yield s"def $v = ???"
118+
CaseClassFunctions ++ fields
119+
case _ =>
120+
Seq.empty
127121
}
128122

129123
override def injectSupers(source: ScTypeDefinition): Seq[String] =
130124
source match {
131-
case c: ScClass if bqAnnotation(c).isDefined => Seq(CaseClassSuper)
125+
case c: ScClass if bqAnnotation(c).isDefined => Seq(CaseClassSuper, HasAnnotationSuper)
132126
case _ => Seq.empty
133127
}
134128

@@ -143,7 +137,7 @@ final class BigQueryTypeInjector extends AnnotationTypeInjector {
143137
case c: ScClass if bqAnnotation(c).isDefined =>
144138
val (annotated, other) =
145139
generatedCaseClasses(source.getQualifiedName.init, c).partition(
146-
_.contains(CaseClassSuper)
140+
_.contains(HasAnnotationSuper)
147141
)
148142
(c, (annotated.headOption, other))
149143
}

0 commit comments

Comments
 (0)