Skip to content

Add support for nullable fields in deep equals #768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,12 @@ sealed class UtReferenceModel(
) : UtModel(classId)

/**
* Checks if [UtModel] is a null.
* Checks if [UtModel] is a [UtNullModel].
*/
fun UtModel.isNull() = this is UtNullModel

/**
* Checks if [UtModel] is not a null.
* Checks if [UtModel] is not a [UtNullModel].
*/
fun UtModel.isNotNull() = !isNull()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class CodeGenerator(
testFramework = testFramework,
mockFramework = mockFramework ?: MockFramework.MOCKITO,
codegenLanguage = codegenLanguage,
parameterizedTestSource = parameterizedTestSource,
parametrizedTestSource = parameterizedTestSource,
staticsMocking = staticsMocking,
forceStaticMocking = forceStaticMocking,
generateWarningsForStaticMocking = generateWarningsForStaticMocking,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ internal interface CgContextOwner {

val codegenLanguage: CodegenLanguage

val parameterizedTestSource: ParametrizedTestSource
val parametrizedTestSource: ParametrizedTestSource

// flag indicating whether a mock framework is used in the generated code
var mockFrameworkUsed: Boolean
Expand Down Expand Up @@ -214,6 +214,8 @@ internal interface CgContextOwner {

var statesCache: EnvironmentFieldStateCache

var allExecutions: List<UtExecution>

fun block(init: () -> Unit): Block {
val prevBlock = currentBlock
return try {
Expand Down Expand Up @@ -407,7 +409,7 @@ internal data class CgContext(
override val forceStaticMocking: ForceStaticMocking,
override val generateWarningsForStaticMocking: Boolean,
override val codegenLanguage: CodegenLanguage = CodegenLanguage.defaultItem,
override val parameterizedTestSource: ParametrizedTestSource = ParametrizedTestSource.DO_NOT_PARAMETRIZE,
override val parametrizedTestSource: ParametrizedTestSource = ParametrizedTestSource.DO_NOT_PARAMETRIZE,
override var mockFrameworkUsed: Boolean = false,
override var currentBlock: PersistentList<CgStatement> = persistentListOf(),
override var existingVariableNames: PersistentSet<String> = persistentSetOf(),
Expand All @@ -427,6 +429,7 @@ internal data class CgContext(
) : CgContextOwner {
override lateinit var statesCache: EnvironmentFieldStateCache
override lateinit var actual: CgVariable
override lateinit var allExecutions: List<UtExecution>

/**
* This property cannot be accessed outside of test class file scope
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import org.utbot.framework.codegen.model.tree.CgExecutableCall
import org.utbot.framework.codegen.model.tree.CgExpression
import org.utbot.framework.codegen.model.tree.CgFieldAccess
import org.utbot.framework.codegen.model.tree.CgGetJavaClass
import org.utbot.framework.codegen.model.tree.CgIsInstance
import org.utbot.framework.codegen.model.tree.CgLiteral
import org.utbot.framework.codegen.model.tree.CgMethod
import org.utbot.framework.codegen.model.tree.CgMethodCall
Expand Down Expand Up @@ -111,9 +110,10 @@ import org.utbot.framework.plugin.api.UtStaticMethodInstrumentation
import org.utbot.framework.plugin.api.UtSymbolicExecution
import org.utbot.framework.plugin.api.UtTimeoutException
import org.utbot.framework.plugin.api.UtVoidModel
import org.utbot.framework.plugin.api.isNotNull
import org.utbot.framework.plugin.api.isNull
import org.utbot.framework.plugin.api.onFailure
import org.utbot.framework.plugin.api.onSuccess
import org.utbot.framework.plugin.api.util.booleanClassId
import org.utbot.framework.plugin.api.util.doubleArrayClassId
import org.utbot.framework.plugin.api.util.doubleClassId
import org.utbot.framework.plugin.api.util.doubleWrapperClassId
Expand Down Expand Up @@ -144,7 +144,6 @@ import org.utbot.summary.SummarySentenceConstants.TAB
import java.lang.reflect.InvocationTargetException
import java.security.AccessControlException
import java.lang.reflect.ParameterizedType
import kotlin.reflect.jvm.javaType

private const val DEEP_EQUALS_MAX_DEPTH = 5 // TODO move it to plugin settings?

Expand All @@ -168,6 +167,8 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c

private lateinit var methodType: CgTestMethodType

private val fieldsOfExecutionResults = mutableMapOf<Pair<FieldId, Int>, MutableList<UtModel>>()

private fun setupInstrumentation() {
if (currentExecution is UtSymbolicExecution) {
val execution = currentExecution as UtSymbolicExecution
Expand Down Expand Up @@ -445,7 +446,6 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
val expectedExpression = CgNotNullAssertion(expectedVariable)

assertEquality(expectedExpression, actual)
println()
}
}
.onFailure { thisInstance[method](*methodArguments.toTypedArray()).intercepted() }
Expand Down Expand Up @@ -537,7 +537,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
doubleDelta
)
expectedModel.value is Boolean -> {
when (parameterizedTestSource) {
when (parametrizedTestSource) {
ParametrizedTestSource.DO_NOT_PARAMETRIZE ->
if (expectedModel.value as Boolean) {
assertions[assertTrue](actual)
Expand Down Expand Up @@ -842,6 +842,25 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
return
}

when (parametrizedTestSource) {
ParametrizedTestSource.DO_NOT_PARAMETRIZE -> {
traverseField(fieldId, fieldModel, expected, actual, depth, visitedModels)
}

ParametrizedTestSource.PARAMETRIZE -> {
traverseFieldForParametrizedTest(fieldId, fieldModel, expected, actual, depth, visitedModels)
}
}
}

private fun traverseField(
fieldId: FieldId,
fieldModel: UtModel,
expected: CgVariable,
actual: CgVariable,
depth: Int,
visitedModels: MutableSet<UtModel>
) {
// fieldModel is not visited and will be marked in assertDeepEquals call
val fieldName = fieldId.name
var expectedVariable: CgVariable? = null
Expand All @@ -866,6 +885,140 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
emptyLineIfNeeded()
}

private fun traverseFieldForParametrizedTest(
fieldId: FieldId,
fieldModel: UtModel,
expected: CgVariable,
actual: CgVariable,
depth: Int,
visitedModels: MutableSet<UtModel>
) {
val fieldResultModels = fieldsOfExecutionResults[fieldId to depth]
val nullResultModelInExecutions = fieldResultModels?.find { it.isNull() }
val notNullResultModelInExecutions = fieldResultModels?.find { it.isNotNull() }

val hasNullResultModel = nullResultModelInExecutions != null
val hasNotNullResultModel = notNullResultModelInExecutions != null

val needToSubstituteFieldModel = fieldModel is UtNullModel && hasNotNullResultModel

val fieldModelForAssert = if (needToSubstituteFieldModel) notNullResultModelInExecutions!! else fieldModel

// fieldModel is not visited and will be marked in assertDeepEquals call
val fieldName = fieldId.name
var expectedVariable: CgVariable? = null

val needExpectedDeclaration = needExpectedDeclaration(fieldModelForAssert)
if (needExpectedDeclaration) {
val expectedFieldDeclaration = createDeclarationForFieldFromVariable(fieldId, expected, fieldName)

currentBlock += expectedFieldDeclaration
expectedVariable = expectedFieldDeclaration.variable
}

val actualFieldDeclaration = createDeclarationForFieldFromVariable(fieldId, actual, fieldName)
currentBlock += actualFieldDeclaration

if (needExpectedDeclaration && hasNullResultModel) {
ifStatement(
CgEqualTo(expectedVariable!!, nullLiteral()),
trueBranch = { +testFrameworkManager.assertions[testFramework.assertNull](actualFieldDeclaration.variable).toStatement() },
falseBranch = {
assertDeepEquals(
fieldModelForAssert,
expectedVariable,
actualFieldDeclaration.variable,
depth + 1,
visitedModels,
)
}
)
} else {
assertDeepEquals(
fieldModelForAssert,
expectedVariable,
actualFieldDeclaration.variable,
depth + 1,
visitedModels,
)
}
emptyLineIfNeeded()
}

private fun collectExecutionsResultFields() {
val successfulExecutionsModels = allExecutions
.filter {
it.result is UtExecutionSuccess
}.map {
(it.result as UtExecutionSuccess).model
}

for (model in successfulExecutionsModels) {
when (model) {
is UtCompositeModel -> {
for ((fieldId, fieldModel) in model.fields) {
collectExecutionsResultFieldsRecursively(fieldId, fieldModel, 0)
}
}

is UtAssembleModel -> {
model.origin?.let {
for ((fieldId, fieldModel) in it.fields) {
collectExecutionsResultFieldsRecursively(fieldId, fieldModel, 0)
}
}
}

is UtNullModel,
is UtPrimitiveModel,
is UtArrayModel,
is UtClassRefModel,
is UtEnumConstantModel,
is UtVoidModel -> {
// only [UtCompositeModel] and [UtAssembleModel] have fields to traverse
}
}
}
}

private fun collectExecutionsResultFieldsRecursively(
fieldId: FieldId,
fieldModel: UtModel,
depth: Int,
) {
if (depth >= DEEP_EQUALS_MAX_DEPTH) {
return
}

val fieldKey = fieldId to depth
fieldsOfExecutionResults.getOrPut(fieldKey) { mutableListOf() } += fieldModel

when (fieldModel) {
is UtCompositeModel -> {
for ((id, model) in fieldModel.fields) {
collectExecutionsResultFieldsRecursively(id, model, depth + 1)
}
}

is UtAssembleModel -> {
fieldModel.origin?.let {
for ((id, model) in it.fields) {
collectExecutionsResultFieldsRecursively(id, model, depth + 1)
}
}
}

is UtNullModel,
is UtPrimitiveModel,
is UtArrayModel,
is UtClassRefModel,
is UtEnumConstantModel,
is UtVoidModel -> {
// only [UtCompositeModel] and [UtAssembleModel] have fields to traverse
}
}
}

@Suppress("UNUSED_ANONYMOUS_PARAMETER")
private fun createDeclarationForFieldFromVariable(
fieldId: FieldId,
Expand Down Expand Up @@ -999,7 +1152,7 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
}
expected == nullLiteral() -> testFrameworkManager.assertNull(actual)
expected is CgLiteral && expected.value is Boolean -> {
when (parameterizedTestSource) {
when (parametrizedTestSource) {
ParametrizedTestSource.DO_NOT_PARAMETRIZE ->
testFrameworkManager.assertBoolean(expected.value, actual)
ParametrizedTestSource.PARAMETRIZE ->
Expand Down Expand Up @@ -1054,15 +1207,19 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
expected: CgValue,
actual: CgVariable,
) {
when (parameterizedTestSource) {
when (parametrizedTestSource) {
ParametrizedTestSource.DO_NOT_PARAMETRIZE -> generateDeepEqualsAssertion(expected, actual)
ParametrizedTestSource.PARAMETRIZE -> when {
actual.type.isPrimitive -> generateDeepEqualsAssertion(expected, actual)
else -> ifStatement(
CgEqualTo(expected, nullLiteral()),
trueBranch = { +testFrameworkManager.assertions[testFramework.assertNull](actual).toStatement() },
falseBranch = { generateDeepEqualsAssertion(expected, actual) }
)
ParametrizedTestSource.PARAMETRIZE -> {
collectExecutionsResultFields()

when {
actual.type.isPrimitive -> generateDeepEqualsAssertion(expected, actual)
else -> ifStatement(
CgEqualTo(expected, nullLiteral()),
trueBranch = { +testFrameworkManager.assertions[testFramework.assertNull](actual).toStatement() },
falseBranch = { generateDeepEqualsAssertion(expected, actual) }
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,13 @@ internal class CgTestClassConstructor(val context: CgContext) :
return null
}

allExecutions = testSet.executions

val (methodUnderTest, _, _, clustersInfo) = testSet
val regions = mutableListOf<CgRegion<CgMethod>>()
val requiredFields = mutableListOf<CgParameterDeclaration>()

when (context.parameterizedTestSource) {
when (context.parametrizedTestSource) {
ParametrizedTestSource.DO_NOT_PARAMETRIZE -> {
for ((clusterSummary, executionIndices) in clustersInfo) {
val currentTestCaseTestMethods = mutableListOf<CgTestMethod>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,13 @@ class ClassWithNullableFieldTest : UtValueTestCaseChecker(
coverage = DoNotCalculate
)
}

@Test
fun testClassWithNullableField1() {
check(
ClassWithNullableField::returnGreatCompoundWithNullableField,
eq(3),
coverage = DoNotCalculate
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,23 @@ class Compound {
}
}

class GreatCompound {
Compound compound;

GreatCompound(Compound compound) {
this.compound = compound;
}
}

public class ClassWithNullableField {
public Compound returnCompoundWithNullableField(int value) {
if (value > 0) return new Compound(null);
else return new Compound(new Component());
}

public GreatCompound returnGreatCompoundWithNullableField(int value) {
if (value > 0) return new GreatCompound(null);
else if (value == 0) return new GreatCompound(new Compound(new Component()));
else return new GreatCompound(new Compound(null));
}
}