1
1
namespace AgileObjects . ReadableExpressions . Translations ;
2
2
3
+ using System ;
3
4
#if NET35
4
5
using Microsoft . Scripting . Ast ;
5
- using static Microsoft . Scripting . Ast . ExpressionType ;
6
6
#else
7
7
using System . Linq . Expressions ;
8
+ #endif
9
+ using Extensions ;
10
+ using NetStandardPolyfills ;
11
+ #if NET35
12
+ using static Microsoft . Scripting . Ast . ExpressionType ;
13
+ #else
8
14
using static System . Linq . Expressions . ExpressionType ;
9
15
#endif
10
16
@@ -124,17 +130,75 @@ public static INodeTranslation For(
124
130
125
131
case Equal :
126
132
case NotEqual :
127
- if ( StandaloneEqualityComparisonTranslation . TryGetTranslation ( binary , context , out var translation ) )
133
+ if ( BoolEqualityComparisonTranslation . TryGetTranslation ( binary , context , out var boolComparison ) )
128
134
{
129
- return translation ;
135
+ return boolComparison ;
130
136
}
131
137
138
+ goto default ;
139
+
140
+ default :
141
+ TryGetEnumComparisonExpression ( ref binary ) ;
132
142
break ;
133
143
}
134
144
135
145
return new BinaryTranslation ( binary , context ) ;
136
146
}
137
147
148
+ public static void TryGetEnumComparisonExpression (
149
+ ref BinaryExpression comparison )
150
+ {
151
+ var leftOperandIsEnum =
152
+ IsEnumType ( comparison . Left , out var leftExpression ) ;
153
+
154
+ var rightOperandIsEnum =
155
+ IsEnumType ( comparison . Right , out var rightExpression ) ;
156
+
157
+ if ( leftOperandIsEnum || rightOperandIsEnum )
158
+ {
159
+ var enumType = leftOperandIsEnum
160
+ ? leftExpression . Type : rightExpression . Type ;
161
+
162
+ comparison = comparison . Update (
163
+ GetEnumValue ( leftExpression , enumType ) ,
164
+ comparison . Conversion ,
165
+ GetEnumValue ( rightExpression , enumType ) ) ;
166
+ }
167
+ }
168
+
169
+ private static bool IsEnumType (
170
+ Expression expression ,
171
+ out Expression enumExpression )
172
+ {
173
+ if ( expression . NodeType . IsCast ( ) )
174
+ {
175
+ expression = expression . GetUnaryOperand ( ) ;
176
+ }
177
+
178
+ if ( expression . Type . GetNonNullableType ( ) . IsEnum ( ) )
179
+ {
180
+ enumExpression = expression ;
181
+ return true ;
182
+ }
183
+
184
+ enumExpression = expression ;
185
+ return false ;
186
+ }
187
+
188
+ private static Expression GetEnumValue (
189
+ Expression expression ,
190
+ Type enumType )
191
+ {
192
+ if ( expression . NodeType != Constant )
193
+ {
194
+ return expression ;
195
+ }
196
+
197
+ var value = ( ( ConstantExpression ) expression ) . Value ;
198
+ var enumValue = Enum . Parse ( enumType , value . ToString ( ) ) ;
199
+ return Expression . Constant ( enumValue , enumType ) ;
200
+ }
201
+
138
202
#endregion
139
203
140
204
public static bool IsBinary ( ExpressionType nodeType )
@@ -173,13 +237,13 @@ public void WriteTo(TranslationWriter writer)
173
237
protected override bool IsMultiStatement ( )
174
238
=> _leftOperandTranslation . IsMultiStatement ( ) || _rightOperandTranslation . IsMultiStatement ( ) ;
175
239
176
- private class StandaloneEqualityComparisonTranslation : INodeTranslation
240
+ private class BoolEqualityComparisonTranslation : INodeTranslation
177
241
{
178
242
private readonly ITranslationContext _context ;
179
243
private readonly StandaloneBoolean _standaloneBoolean ;
180
244
private readonly INodeTranslation _operandTranslation ;
181
245
182
- private StandaloneEqualityComparisonTranslation (
246
+ private BoolEqualityComparisonTranslation (
183
247
ExpressionType nodeType ,
184
248
Expression boolean ,
185
249
ExpressionType @operator ,
@@ -199,7 +263,7 @@ public static bool TryGetTranslation(
199
263
{
200
264
if ( IsBooleanConstant ( comparison . Right ) )
201
265
{
202
- translation = new StandaloneEqualityComparisonTranslation (
266
+ translation = new BoolEqualityComparisonTranslation (
203
267
comparison . NodeType ,
204
268
comparison . Left ,
205
269
comparison . NodeType ,
@@ -211,7 +275,7 @@ public static bool TryGetTranslation(
211
275
212
276
if ( IsBooleanConstant ( comparison . Left ) )
213
277
{
214
- translation = new StandaloneEqualityComparisonTranslation (
278
+ translation = new BoolEqualityComparisonTranslation (
215
279
comparison . NodeType ,
216
280
comparison . Right ,
217
281
comparison . NodeType ,
0 commit comments