Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated AuthorizationValidationRule to skip authorization checks when… #30

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,54 @@ public void fails_on_missing_claim_on_connection_type()
});
}

[Fact]
public void passes_when_field_is_not_included()
{
Settings.AddPolicy("FieldPolicy", _ => _.RequireClaim("admin"));

ShouldPassRule(_ =>
{
_.Query = @"query { post @include(if: false) }";
_.Schema = BasicSchema();
});
}

[Fact]
public void fails_when_field_is_included()
{
Settings.AddPolicy("FieldPolicy", _ => _.RequireClaim("admin"));

ShouldFailRule(_ =>
{
_.Query = @"query { post @include(if: true) }";
_.Schema = BasicSchema();
});
}

[Fact]
public void passes_when_field_is_skipped()
{
Settings.AddPolicy("FieldPolicy", _ => _.RequireClaim("admin"));

ShouldPassRule(_ =>
{
_.Query = @"query { post @skip(if: true) }";
_.Schema = BasicSchema();
});
}

[Fact]
public void fails_when_field_is_not_skipped()
{
Settings.AddPolicy("FieldPolicy", _ => _.RequireClaim("admin"));

ShouldFailRule(_ =>
{
_.Query = @"query { post @skip(if: false) }";
_.Schema = BasicSchema();
});
}

private static ISchema BasicSchema()
{
string defs = @"
Expand Down
44 changes: 43 additions & 1 deletion src/GraphQL.Authorization/AuthorizationValidationRule.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Linq;
using System.Threading.Tasks;
using GraphQL.Execution;
using GraphQL.Language.AST;
using GraphQL.Types;
using GraphQL.Validation;
Expand Down Expand Up @@ -56,7 +58,7 @@ public Task<INodeVisitor> ValidateAsync(ValidationContext context)
{
var fieldDef = context.TypeInfo.GetFieldDef();

if (fieldDef == null)
if (fieldDef == null || SkipAuthCheck(fieldAst, context))
return;

// check target field
Expand All @@ -67,6 +69,46 @@ public Task<INodeVisitor> ValidateAsync(ValidationContext context)
));
}

private bool SkipAuthCheck(Field field, ValidationContext context)
{
if (field.Directives == null || !field.Directives.Any())
return false;

var operationName = context.OperationName;
var documentOperations = context.Document.Operations;
var operation = !string.IsNullOrWhiteSpace(operationName)
? documentOperations.WithName(operationName)
: documentOperations.FirstOrDefault();
var variables = ExecutionHelper.GetVariableValues(context.Document, context.Schema,
operation?.Variables, context.Inputs);

var includeField = GetDirectiveValue(context, field.Directives, DirectiveGraphType.Include, variables);
if (includeField.HasValue)
return !includeField.Value;

var skipField = GetDirectiveValue(context, field.Directives, DirectiveGraphType.Skip, variables);
if (skipField.HasValue)
return skipField.Value;

return false;
}

private static bool? GetDirectiveValue(ValidationContext context, Directives directives, DirectiveGraphType directiveType, Variables variables)
{
var directive = directives.Find(directiveType.Name);
if (directive == null)
return null;

var argumentValues = ExecutionHelper.GetArgumentValues(
context.Schema,
directiveType.Arguments,
directive.Arguments,
variables);

argumentValues.TryGetValue("if", out object ifObj);
return bool.TryParse(ifObj?.ToString() ?? string.Empty, out bool ifVal) && ifVal;
}

private void CheckAuth(
INode node,
IProvideMetadata provider,
Expand Down