Skip to content

bugfix: fix mutations data loading dispatching when defer is enabled #3994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,55 +44,81 @@ private static class CallStack {
private final LockKit.ReentrantLock lock = new LockKit.ReentrantLock();
private final LevelMap expectedFetchCountPerLevel = new LevelMap();
private final LevelMap fetchCountPerLevel = new LevelMap();
private final LevelMap expectedStrategyCallsPerLevel = new LevelMap();
private final LevelMap happenedStrategyCallsPerLevel = new LevelMap();

private final LevelMap expectedExecuteObjectCallsPerLevel = new LevelMap();
private final LevelMap happenedExecuteObjectCallsPerLevel = new LevelMap();

private final LevelMap happenedOnFieldValueCallsPerLevel = new LevelMap();

private final Set<Integer> dispatchedLevels = new LinkedHashSet<>();

public CallStack() {
expectedStrategyCallsPerLevel.set(1, 1);
expectedExecuteObjectCallsPerLevel.set(1, 1);
}

void increaseExpectedFetchCount(int level, int count) {
expectedFetchCountPerLevel.increment(level, count);
}

void clearExpectedFetchCount() {
expectedFetchCountPerLevel.clear();
}

void increaseFetchCount(int level) {
fetchCountPerLevel.increment(level, 1);
}

void increaseExpectedStrategyCalls(int level, int count) {
expectedStrategyCallsPerLevel.increment(level, count);
void clearFetchCount() {
fetchCountPerLevel.clear();
}

void increaseExpectedExecuteObjectCalls(int level, int count) {
expectedExecuteObjectCallsPerLevel.increment(level, count);
}

void increaseHappenedStrategyCalls(int level) {
happenedStrategyCallsPerLevel.increment(level, 1);
void clearExpectedObjectCalls() {
expectedExecuteObjectCallsPerLevel.clear();
}

void increaseHappenedExecuteObjectCalls(int level) {
happenedExecuteObjectCallsPerLevel.increment(level, 1);
}

void clearHappenedExecuteObjectCalls() {
happenedExecuteObjectCallsPerLevel.clear();
}

void increaseHappenedOnFieldValueCalls(int level) {
happenedOnFieldValueCallsPerLevel.increment(level, 1);
}

boolean allStrategyCallsHappened(int level) {
return happenedStrategyCallsPerLevel.get(level) == expectedStrategyCallsPerLevel.get(level);
void clearHappenedOnFieldValueCalls() {
happenedOnFieldValueCallsPerLevel.clear();
}

boolean allExecuteObjectCallsHappened(int level) {
return happenedExecuteObjectCallsPerLevel.get(level) == expectedExecuteObjectCallsPerLevel.get(level);
}

boolean allOnFieldCallsHappened(int level) {
return happenedOnFieldValueCallsPerLevel.get(level) == expectedStrategyCallsPerLevel.get(level);
return happenedOnFieldValueCallsPerLevel.get(level) == expectedExecuteObjectCallsPerLevel.get(level);
}

boolean allFetchesHappened(int level) {
return fetchCountPerLevel.get(level) == expectedFetchCountPerLevel.get(level);
}

void clearDispatchLevels() {
dispatchedLevels.clear();
}

@Override
public String toString() {
return "CallStack{" +
"expectedFetchCountPerLevel=" + expectedFetchCountPerLevel +
", fetchCountPerLevel=" + fetchCountPerLevel +
", expectedStrategyCallsPerLevel=" + expectedStrategyCallsPerLevel +
", happenedStrategyCallsPerLevel=" + happenedStrategyCallsPerLevel +
", expectedExecuteObjectCallsPerLevel=" + expectedExecuteObjectCallsPerLevel +
", happenedExecuteObjectCallsPerLevel=" + happenedExecuteObjectCallsPerLevel +
", happenedOnFieldValueCallsPerLevel=" + happenedOnFieldValueCallsPerLevel +
", dispatchedLevels" + dispatchedLevels +
'}';
Expand Down Expand Up @@ -125,16 +151,14 @@ public void executionStrategy(ExecutionContext executionContext, ExecutionStrate
return;
}
int curLevel = parameters.getExecutionStepInfo().getPath().getLevel() + 1;
increaseCallCounts(curLevel, parameters);
increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, parameters);

}

@Override
public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
if (this.startedDeferredExecution.get()) {
return;
}
int curLevel = parameters.getExecutionStepInfo().getPath().getLevel() + 1;
increaseCallCounts(curLevel, parameters);
public void executionSerialStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
resetCallStack();
increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(1, 1);
}

@Override
Expand All @@ -145,13 +169,24 @@ public void executionStrategyOnFieldValuesInfo(List<FieldValueInfo> fieldValueIn
onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, 1);
}

@Override
public void executionStrategyOnFieldValuesException(Throwable t) {
callStack.lock.runLocked(() ->
callStack.increaseHappenedOnFieldValueCalls(1)
);
}


@Override
public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
if (this.startedDeferredExecution.get()) {
return;
}
int curLevel = parameters.getExecutionStepInfo().getPath().getLevel() + 1;
increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, parameters);
}



@Override
public void executeObjectOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, ExecutionStrategyParameters parameters) {
if (this.startedDeferredExecution.get()) {
Expand All @@ -170,45 +205,34 @@ public void executeObjectOnFieldValuesException(Throwable t, ExecutionStrategyPa
);
}

@Override
public void fieldFetched(ExecutionContext executionContext,
ExecutionStrategyParameters parameters,
DataFetcher<?> dataFetcher,
Object fetchedValue) {

final boolean dispatchNeeded;

if (parameters.getField().isDeferred() || this.startedDeferredExecution.get()) {
this.startedDeferredExecution.set(true);
dispatchNeeded = true;
} else {
int level = parameters.getPath().getLevel();
dispatchNeeded = callStack.lock.callLocked(() -> {
callStack.increaseFetchCount(level);
return dispatchIfNeeded(level);
});
}

if (dispatchNeeded) {
dispatch();
}

}

private void increaseCallCounts(int curLevel, ExecutionStrategyParameters parameters) {
int count = 0;
private void increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(int curLevel, ExecutionStrategyParameters parameters) {
int nonDeferredFields = 0;
for (MergedField field : parameters.getFields().getSubFieldsList()) {
if (!field.isDeferred()) {
count++;
nonDeferredFields++;
}
}
int nonDeferredFieldCount = count;
increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, nonDeferredFields);
}

private void increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(int curLevel, int fieldCount) {
callStack.lock.runLocked(() -> {
callStack.increaseExpectedFetchCount(curLevel, nonDeferredFieldCount);
callStack.increaseHappenedStrategyCalls(curLevel);
callStack.increaseHappenedExecuteObjectCalls(curLevel);
callStack.increaseExpectedFetchCount(curLevel, fieldCount);
});
}

private void resetCallStack() {
callStack.lock.runLocked(() -> {
callStack.clearDispatchLevels();
callStack.clearExpectedObjectCalls();
callStack.clearExpectedFetchCount();
callStack.clearFetchCount();
callStack.clearHappenedExecuteObjectCalls();
callStack.clearHappenedOnFieldValueCalls();
callStack.expectedExecuteObjectCallsPerLevel.set(1, 1);
});
}
private void onFieldValuesInfoDispatchIfNeeded(List<FieldValueInfo> fieldValueInfoList, int curLevel) {
boolean dispatchNeeded = callStack.lock.callLocked(() ->
handleOnFieldValuesInfo(fieldValueInfoList, curLevel)
Expand All @@ -223,23 +247,53 @@ private void onFieldValuesInfoDispatchIfNeeded(List<FieldValueInfo> fieldValueIn
//
private boolean handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfos, int curLevel) {
callStack.increaseHappenedOnFieldValueCalls(curLevel);
int expectedStrategyCalls = getCountForList(fieldValueInfos);
callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls);
int expectedStrategyCalls = getObjectCountForList(fieldValueInfos);
callStack.increaseExpectedExecuteObjectCalls(curLevel + 1, expectedStrategyCalls);
return dispatchIfNeeded(curLevel + 1);
}

private int getCountForList(List<FieldValueInfo> fieldValueInfos) {
/**
* the amount of (non nullable) objects that will require an execute object call
*/
private int getObjectCountForList(List<FieldValueInfo> fieldValueInfos) {
int result = 0;
for (FieldValueInfo fieldValueInfo : fieldValueInfos) {
if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.OBJECT) {
result += 1;
} else if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.LIST) {
result += getCountForList(fieldValueInfo.getFieldValueInfos());
result += getObjectCountForList(fieldValueInfo.getFieldValueInfos());
}
}
return result;
}


@Override
public void fieldFetched(ExecutionContext executionContext,
ExecutionStrategyParameters executionStrategyParameters,
DataFetcher<?> dataFetcher,
Object fetchedValue) {

final boolean dispatchNeeded;

if (executionStrategyParameters.getField().isDeferred() || this.startedDeferredExecution.get()) {
this.startedDeferredExecution.set(true);
dispatchNeeded = true;
} else {
int level = executionStrategyParameters.getPath().getLevel();
dispatchNeeded = callStack.lock.callLocked(() -> {
callStack.increaseFetchCount(level);
return dispatchIfNeeded(level);
});
}

if (dispatchNeeded) {
dispatch();
}

}


//
// thread safety : called with callStack.lock
//
Expand All @@ -260,7 +314,7 @@ private boolean levelReady(int level) {
return callStack.allFetchesHappened(1);
}
if (levelReady(level - 1) && callStack.allOnFieldCallsHappened(level - 1)
&& callStack.allStrategyCallsHappened(level) && callStack.allFetchesHappened(level)) {
&& callStack.allExecuteObjectCallsHappened(level) && callStack.allFetchesHappened(level)) {

return true;
}
Expand Down
20 changes: 16 additions & 4 deletions src/test/groovy/graphql/MutationTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -141,23 +141,26 @@ class MutationTest extends Specification {
]])

def graphQL = GraphQL.newGraphQL(schema).build()

when:
def er = graphQL.execute("""
def ei = ExecutionInput.newExecutionInput("""
mutation m {
plus1(arg:10)
plus2(arg:10)
plus3(arg:10)
}
""")
""").build()
ei.getGraphQLContext().put(ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT, defeEnabled)

when:
def er = graphQL.execute(ei)
then:
er.errors.isEmpty()
er.data == [
plus1: 11,
plus2: 12,
plus3: 13,
]
where:
defeEnabled << [true, false]
}

def "simple async mutation with DataLoader"() {
Expand Down Expand Up @@ -213,6 +216,7 @@ class MutationTest extends Specification {
plus3(arg:10)
}
""").dataLoaderRegistry(dlReg).build()
ei.getGraphQLContext().put(ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT, defeEnabled)
when:
def er = graphQL.execute(ei)

Expand All @@ -223,12 +227,16 @@ class MutationTest extends Specification {
plus2: 12,
plus3: 13,
]

where:
defeEnabled << [true, false]
}

/*
This test shows a dataloader being called at the mutation field level, in serial via AsyncSerialExecutionStrategy, and then
again at the sub field level, in parallel, via AsyncExecutionStrategy.
*/

def "more complex async mutation with DataLoader"() {
def sdl = """
type Query {
Expand Down Expand Up @@ -436,6 +444,7 @@ class MutationTest extends Specification {
}
}
""").dataLoaderRegistry(dlReg).build()
ei.getGraphQLContext().put(ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT, defeEnabled)
when:
def cf = graphQL.executeAsync(ei)

Expand All @@ -459,5 +468,8 @@ class MutationTest extends Specification {
topLevelF3: expectedMap,
topLevelF4: expectedMap,
]

where:
defeEnabled << [true, false]
}
}
Loading