Skip to content

Commit

Permalink
[python] First step in supporting type inference, based on the usage …
Browse files Browse the repository at this point in the history
…of `sq_concat` (#207)

* Add sq_concat support

* Add "list_concat_usage" test

* Changed test to detect PyMockTypeStream issues

* Add sq_concat notification

* fix: more accurate nbAddKt function

* fix linter
  • Loading branch information
jefremof authored Aug 16, 2024
1 parent 2a277d2 commit 37050f9
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ JNIEXPORT jint JNICALL Java_org_usvm_interpreter_CPythonAdapter_typeHasNbPositiv
return type->tp_as_number && type->tp_as_number->nb_positive;
}

JNIEXPORT jint JNICALL Java_org_usvm_interpreter_CPythonAdapter_typeHasSqConcat(JNIEnv *env, jobject _, jlong type_ref) {
QUERY_TYPE_HAS_PREFIX
return type->tp_as_sequence && type->tp_as_sequence->sq_concat;
}

JNIEXPORT jint JNICALL Java_org_usvm_interpreter_CPythonAdapter_typeHasSqLength(JNIEnv *env, jobject _, jlong type_ref) {
QUERY_TYPE_HAS_PREFIX
return type->tp_as_sequence && type->tp_as_sequence->sq_length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ val sampleStringFunction = StringProgramProvider(
/**
* Sample of a function that cannot be covered right now.
* */
val listConcatProgram = StringProgramProvider(
val tupleConcatProgram = StringProgramProvider(
"""
def list_concat(x):
y = x + [1]
if len(y[::-1]) == 5:
return 1
return 2
def tuple_concat(x, y):
z = x + y
return z + (1, 2, 3)
""".trimIndent(),
"list_concat",
) { listOf(PythonAnyType) }
"tuple_concat",
) { listOf(PythonAnyType, PythonAnyType) }
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ class SimpleTypeInferenceTest: PythonTestRunnerForPrimitiveProgram("SimpleTypeIn
)
}

@Test
fun testListConcatUsage() {
check2WithConcreteRun(
constructFunction("list_concat_usage", List(2) { PythonAnyType }),
ignoreNumberOfAnalysisResults,
standardConcolicAndConcreteChecks,
/* invariants = */ emptyList(),
/* propertiesToDiscover = */ listOf(
{ _, _, res -> res.selfTypeName == "AssertionError" },
{ _, _, res -> res.repr == "None" }
)
)
}

@Test
fun testLenUsage() {
check1WithConcreteRun(
Expand Down
6 changes: 6 additions & 0 deletions usvm-python/src/test/resources/samples/SimpleTypeInference.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def isinstance_sample(x):
return "Not reachable"


def list_concat_usage(x, y):
z = x + y
z += []
assert z


def len_usage(x):
if len(x) == 5:
return 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class CPythonAdapter {
public native int typeHasNbMatrixMultiply(long type);
public native int typeHasNbNegative(long type);
public native int typeHasNbPositive(long type);
public native int typeHasSqConcat(long type);
public native int typeHasSqLength(long type);
public native int typeHasMpLength(long type);
public native int typeHasMpSubscript(long type);
Expand Down Expand Up @@ -1071,6 +1072,18 @@ public static void notifyNbPositive(ConcolicRunContext context, SymbolForCPython
nbPositiveKt(context, on.obj);
}

@CPythonAdapterJavaMethod(cName = "sq_concat")
@CPythonFunction(
argCTypes = {CType.PyObject, CType.PyObject},
argConverters = {ObjectConverter.StandardConverter, ObjectConverter.StandardConverter}
)
public static void notifySqConcat(ConcolicRunContext context, SymbolForCPython left, SymbolForCPython right) {
if (left.obj == null || right.obj == null)
return;
context.curOperation = new MockHeader(SqConcatMethod.INSTANCE, Arrays.asList(left.obj, right.obj), null);
sqConcatKt(context, left.obj, right.obj);
}

@CPythonAdapterJavaMethod(cName = "sq_length")
@CPythonFunction(
argCTypes = {CType.PyObject},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ data object NbMultiplyMethod : TypeMethod(false)
data object NbMatrixMultiplyMethod : TypeMethod(false)
data object NbNegativeMethod : TypeMethod(false)
data object NbPositiveMethod : TypeMethod(false)
data object SqConcatMethod : TypeMethod(false)
data object SqLengthMethod : TypeMethod(true)
data object MpSubscriptMethod : TypeMethod(false)
data object MpAssSubscriptMethod : TypeMethod(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ object ConcretePythonInterpreter {
val typeHasNbMatrixMultiply = createTypeQuery { pythonAdapter.typeHasNbMatrixMultiply(it) }
val typeHasNbNegative = createTypeQuery { pythonAdapter.typeHasNbNegative(it) }
val typeHasNbPositive = createTypeQuery { pythonAdapter.typeHasNbPositive(it) }
val typeHasSqConcat = createTypeQuery { pythonAdapter.typeHasSqConcat(it) }
val typeHasSqLength = createTypeQuery { pythonAdapter.typeHasSqLength(it) }
val typeHasMpLength = createTypeQuery { pythonAdapter.typeHasMpLength(it) }
val typeHasMpSubscript = createTypeQuery { pythonAdapter.typeHasMpSubscript(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import org.usvm.machine.types.HasNbMultiply
import org.usvm.machine.types.HasNbNegative
import org.usvm.machine.types.HasNbPositive
import org.usvm.machine.types.HasNbSubtract
import org.usvm.machine.types.HasSqConcat
import org.usvm.machine.types.HasSqLength
import org.usvm.machine.types.HasTpCall
import org.usvm.machine.types.HasTpGetattro
Expand All @@ -37,7 +38,17 @@ fun nbAddKt(
context.ctx
) {
context.curState ?: return
pyAssert(context, left.evalIsSoft(context, HasNbAdd) or right.evalIsSoft(context, HasNbAdd))
/*
* The __add__ method corresponds both to the nb_add and sq_concat slots,
* so it is crucial not to assert the presence of nb_add, but to fork on these
* two possible options.
* Moreover, for now it was decided, that operation `sq_concat` makes sense
* only in the situation, when both operands have the corresponding slot.
*/
val nbAdd = left.evalIsSoft(context, HasNbAdd) or right.evalIsSoft(context, HasNbAdd)
val sqConcat = left.evalIsSoft(context, HasSqConcat) and right.evalIsSoft(context, HasSqConcat)
pyAssert(context, nbAdd.not() implies sqConcat)
pyFork(context, nbAdd)
}

fun nbSubtractKt(
Expand Down Expand Up @@ -74,6 +85,17 @@ fun nbPositiveKt(context: ConcolicRunContext, on: UninterpretedSymbolicPythonObj
pyAssert(context, on.evalIsSoft(context, HasNbPositive))
}

fun sqConcatKt(
context: ConcolicRunContext,
left: UninterpretedSymbolicPythonObject,
right: UninterpretedSymbolicPythonObject,
) = with(
context.ctx
) {
context.curState ?: return
pyAssert(context, left.evalIsSoft(context, HasSqConcat) and right.evalIsSoft(context, HasSqConcat))
}

fun sqLengthKt(context: ConcolicRunContext, on: UninterpretedSymbolicPythonObject) {
context.curState ?: return
val sqLength = on.evalIsSoft(context, HasSqLength)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.usvm.language.NbMultiplyMethod
import org.usvm.language.NbNegativeMethod
import org.usvm.language.NbPositiveMethod
import org.usvm.language.NbSubtractMethod
import org.usvm.language.SqConcatMethod
import org.usvm.language.SqLengthMethod
import org.usvm.language.TpCallMethod
import org.usvm.language.TpGetattro
Expand Down Expand Up @@ -95,6 +96,10 @@ class SymbolTypeTree(
listOf(createBinaryProtocol("__mul__", pythonAnyType, returnType))
}

SqConcatMethod -> { returnType: UtType ->
listOf(createBinaryProtocol("__add__", pythonAnyType, returnType))
}

SqLengthMethod -> { _: UtType ->
listOf(createUnaryProtocol("__len__", typeHintsStorage.pythonInt))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ object HasNbPositive : TypeProtocol() {
ConcretePythonInterpreter.typeHasNbPositive(type.asObject)
}

object HasSqConcat : TypeProtocol() {
override fun acceptsConcrete(type: ConcretePythonType): Boolean =
ConcretePythonInterpreter.typeHasSqConcat(type.asObject)
}

object HasSqLength : TypeProtocol() {
override fun acceptsConcrete(type: ConcretePythonType): Boolean =
ConcretePythonInterpreter.typeHasSqLength(type.asObject)
Expand Down

0 comments on commit 37050f9

Please sign in to comment.