diff --git a/console/src/main/scala/io/joern/console/cpgcreation/PythonSrcCpgGenerator.scala b/console/src/main/scala/io/joern/console/cpgcreation/PythonSrcCpgGenerator.scala index 919ce9e51eb2..fe4e522fc76e 100644 --- a/console/src/main/scala/io/joern/console/cpgcreation/PythonSrcCpgGenerator.scala +++ b/console/src/main/scala/io/joern/console/cpgcreation/PythonSrcCpgGenerator.scala @@ -5,7 +5,7 @@ import io.joern.pysrc2cpg.* import io.joern.x2cpg.X2Cpg import io.joern.x2cpg.passes.base.AstLinkerPass import io.joern.x2cpg.passes.callgraph.NaiveCallLinker -import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig +import io.joern.x2cpg.passes.frontend.TypeRecoveryConfig import io.shiftleft.codepropertygraph.Cpg import java.nio.file.Path @@ -32,8 +32,8 @@ case class PythonSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends new DynamicTypeHintFullNamePass(cpg).createAndApply() new PythonInheritanceNamePass(cpg).createAndApply() val typeRecoveryConfig = pyConfig match - case Some(config) => XTypeRecoveryConfig(config.typePropagationIterations, !config.disableDummyTypes) - case None => XTypeRecoveryConfig() + case Some(config) => TypeRecoveryConfig(config.typePropagationIterations, !config.disableDummyTypes) + case None => TypeRecoveryConfig() new PythonTypeRecoveryPass(cpg, typeRecoveryConfig).createAndApply() new PythonTypeHintCallLinker(cpg).createAndApply() new NaiveCallLinker(cpg).createAndApply() diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala index 3d6955b71b4e..697e3d8486f5 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala @@ -9,7 +9,7 @@ import io.joern.javasrc2cpg.passes.{ TypeInferencePass } import io.joern.x2cpg.X2Cpg.withNewEmptyCpg -import io.joern.x2cpg.passes.frontend.{MetaDataPass, TypeNodePass, XTypeRecoveryConfig} +import io.joern.x2cpg.passes.frontend.{MetaDataPass, TypeNodePass, TypeRecoveryConfig} import io.joern.x2cpg.X2CpgFrontend import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.Languages @@ -57,7 +57,7 @@ object JavaSrc2Cpg { def typeRecoveryPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = { List( - new JavaTypeRecoveryPass(cpg, XTypeRecoveryConfig(enabledDummyTypes = !config.exists(_.disableDummyTypes))), + new JavaTypeRecoveryPass(cpg, TypeRecoveryConfig(enabledDummyTypes = !config.exists(_.disableDummyTypes))), new JavaTypeHintCallLinker(cpg) ) } diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala index 8f5e5cc2bd3f..246dc99574b1 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/AstCreationPass.scala @@ -21,7 +21,7 @@ import io.joern.javasrc2cpg.util.{Delombok, SourceParser} import io.joern.javasrc2cpg.{Config, JavaSrc2Cpg} import io.joern.x2cpg.SourceFiles import io.joern.x2cpg.datastructures.Global -import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig +import io.joern.x2cpg.passes.frontend.TypeRecoveryConfig import io.joern.x2cpg.utils.dependency.DependencyResolver import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.passes.ConcurrentWriterCpgPass diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/JavaTypeRecoveryPass.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/JavaTypeRecoveryPass.scala index d51e25d510ff..c754f80760b7 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/JavaTypeRecoveryPass.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/JavaTypeRecoveryPass.scala @@ -1,60 +1,86 @@ package io.joern.javasrc2cpg.passes import io.joern.x2cpg.Defines -import io.joern.x2cpg.passes.frontend._ +import io.joern.x2cpg.passes.frontend.* +import io.joern.x2cpg.passes.frontend.ImportsPass.ResolvedImport import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames -import io.shiftleft.codepropertygraph.generated.nodes._ -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.codepropertygraph.generated.nodes.* import overflowdb.BatchedUpdate.DiffGraphBuilder +import overflowdb.traversal.ImplicitsTmp.toTraversalSugarExt -class JavaTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) - extends XTypeRecoveryPass[Method](cpg, config) { - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[Method] = - new JavaTypeRecovery(cpg, state) -} - -private class JavaTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[Method](cpg, state) { - - override def compilationUnit: Iterator[Method] = cpg.method.isExternal(false).iterator +import java.util.concurrent.ExecutorService - override def generateRecoveryForCompilationUnitTask( - unit: Method, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[Method] = { - val newConfig = state.config.copy(enabledDummyTypes = state.isFinalIteration && state.config.enabledDummyTypes) - new RecoverForJavaFile(cpg, unit, builder, state.copy(config = newConfig)) - } +class JavaTypeRecoveryPass(cpg: Cpg, config: TypeRecoveryConfig = TypeRecoveryConfig()) + extends XTypeRecoveryPass(cpg, config) { + override protected def generateRecoveryPass(state: TypeRecoveryState, executor: ExecutorService): XTypeRecovery = + new JavaTypeRecovery(cpg, state, executor) } -private class RecoverForJavaFile(cpg: Cpg, cu: Method, builder: DiffGraphBuilder, state: XTypeRecoveryState) - extends RecoverForXCompilationUnit[Method](cpg, cu, builder, state) { +private class JavaTypeRecovery(cpg: Cpg, state: TypeRecoveryState, executor: ExecutorService) + extends XTypeRecovery(cpg, state, executor) { + + override protected val initialSymbolTable = new SymbolTable[LocalKey](javaNodeToLocalKey) private def javaNodeToLocalKey(n: AstNode): Option[LocalKey] = n match { case i: Identifier if i.name == "this" && i.code == "super" => Option(LocalVar("super")) case _ => SBKey.fromNodeToLocalKey(n) } - override protected val symbolTable = new SymbolTable[LocalKey](javaNodeToLocalKey) + override protected def recoverTypesForProcedure( + cpg: Cpg, + procedure: Method, + initialSymbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState + ): RecoverTypesForProcedure = + RecoverForJavaFile(cpg, procedure, initialSymbolTable, builder, state) - override protected def isConstructor(c: Call): Boolean = isConstructor(c.name) + override protected def importNodes(cu: File): List[ResolvedImport] = + cu.namespaceBlock.flatMap(_.astOut).collectAll[Import].flatMap(visitImport).l - override protected def isConstructor(name: String): Boolean = !name.isBlank && name.charAt(0).isUpper + // Java has a much simpler import structure that doesn't need resolution + override protected def visitImport(i: Import): Iterator[ImportsPass.ResolvedImport] = { + for { + alias <- i.importedAs + fullName <- i.importedEntity + } { + if (alias != "*") { + initialSymbolTable.append(CallAlias(alias, Option("this")), fullName) + initialSymbolTable.append(LocalVar(alias), fullName) + } + } + Iterator.empty + } override protected def postVisitImports(): Unit = { - symbolTable.view.foreach { case (k, ts) => + initialSymbolTable.view.foreach { case (k, ts) => val tss = ts.filterNot(_.startsWith(Defines.UnresolvedNamespace)) if (tss.isEmpty) - symbolTable.remove(k) + initialSymbolTable.remove(k) else - symbolTable.put(k, tss) + initialSymbolTable.put(k, tss) } } +} + +private class RecoverForJavaFile( + cpg: Cpg, + procedure: Method, + symbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState +) extends RecoverTypesForProcedure(cpg, procedure, symbolTable, builder, state) { + + override protected def isConstructor(c: Call): Boolean = isConstructor(c.name) + + override protected def isConstructor(name: String): Boolean = !name.isBlank && name.charAt(0).isUpper + // There seems to be issues with inferring these, often due to situations where super and this are confused on name // and code properties. override protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = if (i.name != "this") { - super.storeIdentifierTypeInfo(i, types) + super.storeIdentifierTypeInfo(i, types.filterNot(_ == "null")) } override protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = @@ -64,7 +90,9 @@ private class RecoverForJavaFile(cpg: Cpg, cu: Method, builder: DiffGraphBuilder case t if t.endsWith(c.signature) => t case t => s"$t:${c.signature}" } - builder.setNodeProperty(c, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, signedTypes) + if (c.possibleTypes != signedTypes) { + builder.setNodeProperty(c, PropertyNames.POSSIBLE_TYPES, signedTypes) + } } } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala index 5f2562cbd982..fa6a11a4853f 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala @@ -354,11 +354,16 @@ class JavaTypeRecoveryPassTests extends JavaSrcCode2CpgFixture(enableTypeRecover Seq("net", "javaguides", "hibernate", "NamedQueryExample.java").mkString(File.separator) ) - "should be resolved using dummy return values" in { + "receive a full namespace from Java inference" in { + val Some(getResultList) = cpg.call("createNamedQuery").headOption: @unchecked + getResultList.methodFullName shouldBe "org.hibernate.Session.createNamedQuery:(2)" + } + + "resolve the second call using dummy return values" in { val Some(getResultList) = cpg.call("getResultList").headOption: @unchecked // Changes the below from .getResultList:(0) to: getResultList.methodFullName shouldBe "org.hibernate.Session.createNamedQuery:(2)..getResultList:(0)" - getResultList.dynamicTypeHintFullName shouldBe Seq() + getResultList.possibleTypes shouldBe Seq() } "hint that `transaction` may be of the null type" in { diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/JsSrc2Cpg.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/JsSrc2Cpg.scala index e05d55eaa7b6..087c6bcf832e 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/JsSrc2Cpg.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/JsSrc2Cpg.scala @@ -8,7 +8,7 @@ import io.joern.jssrc2cpg.utils.AstGenRunner import io.joern.x2cpg.X2Cpg.withNewEmptyCpg import io.joern.x2cpg.X2CpgFrontend import io.joern.x2cpg.passes.callgraph.NaiveCallLinker -import io.joern.x2cpg.passes.frontend.XTypeRecoveryConfig +import io.joern.x2cpg.passes.frontend.TypeRecoveryConfig import io.joern.x2cpg.utils.{HashUtil, Report} import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.passes.CpgPassBase @@ -58,8 +58,8 @@ object JsSrc2Cpg { def postProcessingPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = { val typeRecoveryConfig = config - .map(c => XTypeRecoveryConfig(c.typePropagationIterations, !c.disableDummyTypes)) - .getOrElse(XTypeRecoveryConfig()) + .map(c => TypeRecoveryConfig(c.typePropagationIterations, !c.disableDummyTypes)) + .getOrElse(TypeRecoveryConfig()) List( new JavaScriptInheritanceNamePass(cpg), new ConstClosurePass(cpg), diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/ImportResolverPass.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/ImportResolverPass.scala index fabb914f3a1f..c346727a90be 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/ImportResolverPass.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/ImportResolverPass.scala @@ -94,9 +94,9 @@ class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { if (methodMatches.nonEmpty) methodMatches.fullName.toSet else constructorMatches.fullName.toSet if (methodPaths.nonEmpty) { - methodPaths.flatMap(x => Set(ResolvedMethod(x, alias, Option("this")), ResolvedTypeDecl(x))) + methodPaths.flatMap(x => Set(ResolvedMethod(x, alias, Option("this")), ResolvedTypeDecl(x, alias))) } else if (moduleExportsThisVariable) { - Set(ResolvedMember(targetModule.fullName.head, b.name)) + Set(ResolvedMember(targetModule.fullName.head, b.name, alias)) } else { Set.empty } @@ -108,7 +108,7 @@ class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { b.referencedMethod.astParent.iterator .collectAll[Method] .fullName - .map(x => ResolvedTypeDecl(x)) + .map(x => ResolvedTypeDecl(x, alias)) .toSet ++ Set(ResolvedMethod(b.methodFullName, callName, receiver)) case ::(_, ::(y: Call, _)) => // Exported closure with a method ref within the AST of the RHS @@ -118,7 +118,7 @@ class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { } }.toSet } else { - Set(UnknownMethod(entity, alias, Option("this")), UnknownTypeDecl(entity)) + Set(UnknownMethod(entity, alias, Option("this")), UnknownTypeDecl(entity, alias)) }).foreach(x => resolvedImportToTag(x, importCall, diffGraph)) } diff --git a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeRecovery.scala b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeRecovery.scala index a5c1ac99d18f..3e7f71c86f1a 100644 --- a/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeRecovery.scala +++ b/joern-cli/frontends/jssrc2cpg/src/main/scala/io/joern/jssrc2cpg/passes/JavaScriptTypeRecovery.scala @@ -1,37 +1,46 @@ package io.joern.jssrc2cpg.passes +import io.joern.x2cpg.Defines as XDefines import io.joern.x2cpg.Defines.ConstructorMethodName -import io.joern.x2cpg.passes.frontend._ -import io.joern.x2cpg.{Defines => XDefines} +import io.joern.x2cpg.passes.frontend.* import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess import overflowdb.BatchedUpdate.DiffGraphBuilder -class JavaScriptTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) - extends XTypeRecoveryPass[File](cpg, config) { - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = - new JavaScriptTypeRecovery(cpg, state) -} +import java.util.concurrent.ExecutorService -private class JavaScriptTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[File](cpg, state) { +class JavaScriptTypeRecoveryPass(cpg: Cpg, config: TypeRecoveryConfig = TypeRecoveryConfig()) + extends XTypeRecoveryPass(cpg, config) { + override protected def generateRecoveryPass(state: TypeRecoveryState, executor: ExecutorService): XTypeRecovery = + new JavaScriptTypeRecovery(cpg, state, executor) +} - override def compilationUnit: Iterator[File] = cpg.file.iterator +private class JavaScriptTypeRecovery(cpg: Cpg, state: TypeRecoveryState, executor: ExecutorService) + extends XTypeRecovery(cpg, state, executor) { - override def generateRecoveryForCompilationUnitTask( - unit: File, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[File] = { - val newConfig = state.config.copy(enabledDummyTypes = state.isFinalIteration && state.config.enabledDummyTypes) - new RecoverForJavaScriptFile(cpg, unit, builder, state.copy(config = newConfig)) - } + override protected def recoverTypesForProcedure( + cpg: Cpg, + procedure: Method, + initialSymbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState + ): RecoverTypesForProcedure = + RecoverForJavaScriptProcedure(cpg, procedure, initialSymbolTable, builder, state) } -private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState) - extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) { +private class RecoverForJavaScriptProcedure( + cpg: Cpg, + procedure: Method, + symbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState +) extends RecoverTypesForProcedure(cpg, procedure, symbolTable, builder, state) { + + import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt override protected val pathSep = ':' @@ -58,48 +67,50 @@ private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBui builder.setNodeProperty(x, PropertyNames.TYPE_FULL_NAME, resolvedTypeHints.head) case x @ (_: Identifier | _: Local | _: MethodParameterIn) => symbolTable.put(x, x.getKnownTypes) - case x: Call => symbolTable.put(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) + case x: Call => symbolTable.put(x, (Seq(x.methodFullName) ++ x.dynamicTypeHintFullName ++ x.possibleTypes).toSet) case _ => } override protected def prepopulateSymbolTable(): Unit = { super.prepopulateSymbolTable() - cu.ast.isMethod.foreach(f => symbolTable.put(CallAlias(f.name, Option("this")), Set(f.fullName))) - (cu.ast.isParameter.whereNot(_.nameExact("this")) ++ cu.ast.isMethod.methodReturn).filter(hasTypes).foreach { p => - val resolvedHints = p.getKnownTypes - .map { t => - t.split("\\.").headOption match { - case Some(base) if symbolTable.contains(LocalVar(base)) => - (t, symbolTable.get(LocalVar(base)).map(x => s"$x${t.stripPrefix(base)}")) - case _ => (t, Set(t)) + procedure.ast.isMethod.foreach(f => symbolTable.put(CallAlias(f.name, Option("this")), Set(f.fullName))) + (procedure.ast.isParameter.whereNot(_.nameExact("this")) ++ procedure.ast.isMethod.methodReturn) + .filter(hasTypes) + .foreach { p => + val resolvedHints = p.getKnownTypes + .map { t => + t.split("\\.").headOption match { + case Some(base) if symbolTable.contains(LocalVar(base)) => + (t, symbolTable.get(LocalVar(base)).map(x => s"$x${t.stripPrefix(base)}")) + case _ => (t, Set(t)) + } } + .flatMap { + case (t, ts) if Set(t) == ts => Set(t) + case (_, ts) => ts.map(_.replaceAll("\\.(?!js::program)", pathSep.toString)) + } + p match { + case _: MethodParameterIn => symbolTable.put(p, resolvedHints) + case _: MethodReturn if resolvedHints.sizeIs == 1 => + builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, resolvedHints.head) + case _: MethodReturn => + builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, Defines.Any) + builder.setNodeProperty(p, PropertyNames.POSSIBLE_TYPES, resolvedHints) + case _ => } - .flatMap { - case (t, ts) if Set(t) == ts => Set(t) - case (_, ts) => ts.map(_.replaceAll("\\.(?!js::program)", pathSep.toString)) - } - p match { - case _: MethodParameterIn => symbolTable.put(p, resolvedHints) - case _: MethodReturn if resolvedHints.sizeIs == 1 => - builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, resolvedHints.head) - case _: MethodReturn => - builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, Defines.Any) - builder.setNodeProperty(p, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, resolvedHints) - case _ => } - } } - private lazy val exportedIdentifiers = cu.method - .nameExact(":program") - .ast - .isCall - .nameExact(Operators.assignment) - .filter(_.code.startsWith("exports.*")) - .argument - .isIdentifier - .name - .toSet + private lazy val exportedIdentifiers = + Iterator(procedure) + .repeat(_._astIn)(_.emit.until(_.collectAll[Method].nameExact(":program"))) + .collectAll[Method] + .assignment + .code("export.*") + .argument + .isIdentifier + .name + .toSet override protected def isField(i: Identifier): Boolean = state.isFieldCache.getOrElseUpdate(i.id(), exportedIdentifiers.contains(i.name) || super.isField(i)) @@ -164,7 +175,8 @@ private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBui } override protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = - if (c.name == "require") Set.empty + if (c.name == "require" || c.code.startsWith("require")) Set.empty + else if (c.name.endsWith(".factory")) symbolTable.append(i, c.methodFullName.stripSuffix(".factory")) else super.visitIdentifierAssignedToCall(i, c) override protected def visitIdentifierAssignedToMethodRef( @@ -183,10 +195,10 @@ private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBui override protected def postSetTypeInformation(): Unit = { // often there are "this" identifiers with type hints but this can be set to a type hint if they meet the criteria - cu.ast.isIdentifier + procedure._identifierViaContainsOut .nameExact("this") .where(_.typeFullNameExact(Defines.Any)) - .filterNot(_.dynamicTypeHintFullName.isEmpty) + .filterNot(_.possibleTypes.isEmpty) .foreach(setTypeFromTypeHints) } diff --git a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/TypeRecoveryPassTests.scala b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/JsTypeRecoveryPassTests.scala similarity index 92% rename from joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/TypeRecoveryPassTests.scala rename to joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/JsTypeRecoveryPassTests.scala index 0d2d98f2a90c..35085c26a5b2 100644 --- a/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/TypeRecoveryPassTests.scala +++ b/joern-cli/frontends/jssrc2cpg/src/test/scala/io/joern/jssrc2cpg/passes/JsTypeRecoveryPassTests.scala @@ -4,7 +4,7 @@ import io.joern.jssrc2cpg.testfixtures.DataFlowCodeToCpgSuite import io.joern.x2cpg.passes.frontend.ImportsPass._ import io.shiftleft.semanticcpg.language._ -class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { +class JsTypeRecoveryPassTests extends DataFlowCodeToCpgSuite { "literals declared from built-in types" should { val cpg = code(""" @@ -22,8 +22,8 @@ class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { "resolve 'x' identifier types despite shadowing" in { val List(xOuterScope, xInnerScope) = cpg.identifier.nameExact("x").l - xOuterScope.dynamicTypeHintFullName shouldBe Seq("__ecma.String", "__ecma.Number") - xInnerScope.dynamicTypeHintFullName shouldBe Seq("__ecma.String", "__ecma.Number") + xOuterScope.possibleTypes shouldBe Seq("__ecma.String", "__ecma.Number") + xInnerScope.possibleTypes shouldBe Seq("__ecma.String", "__ecma.Number") } "resolve 'z' types correctly" in { @@ -57,7 +57,7 @@ class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { "resolve correct imports via tag nodes" in { val List(a: UnknownMethod, b: UnknownTypeDecl, x: UnknownMethod, y: UnknownTypeDecl) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked a.fullName shouldBe "slack_sdk:WebClient" b.fullName shouldBe "slack_sdk:WebClient" x.fullName shouldBe "sendgrid:SendGridAPIClient" @@ -141,7 +141,7 @@ class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { "resolve correct imports via tag nodes" in { val List(a: ResolvedMember, b: ResolvedMember, c: ResolvedMember, d: UnknownMethod, e: UnknownTypeDecl) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked a.basePath shouldBe "Foo.ts::program" a.memberName shouldBe "x" b.basePath shouldBe "Foo.ts::program" @@ -167,25 +167,25 @@ class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { "resolve 'foo.x' and 'foo.y' field access primitive types correctly" in { val List(z1, z2) = cpg.file.name(".*Bar.*").ast.isIdentifier.nameExact("z").l z1.typeFullName shouldBe "ANY" - z1.dynamicTypeHintFullName shouldBe Seq("__ecma.Number", "__ecma.String") + z1.possibleTypes.sorted shouldBe Seq("__ecma.Number", "__ecma.String") z2.typeFullName shouldBe "ANY" - z2.dynamicTypeHintFullName shouldBe Seq("__ecma.Number", "__ecma.String") + z2.possibleTypes.sorted shouldBe Seq("__ecma.Number", "__ecma.String") } "resolve 'foo.d' field access object types correctly" in { val List(d1, d2, d3) = cpg.file.name(".*Bar.*").ast.isIdentifier.nameExact("d").l d1.typeFullName shouldBe "flask_sqlalchemy:SQLAlchemy" - d1.dynamicTypeHintFullName shouldBe Seq() + d1.possibleTypes shouldBe empty d2.typeFullName shouldBe "flask_sqlalchemy:SQLAlchemy" - d2.dynamicTypeHintFullName shouldBe Seq() + d2.possibleTypes shouldBe empty d3.typeFullName shouldBe "flask_sqlalchemy:SQLAlchemy" - d3.dynamicTypeHintFullName shouldBe Seq() + d3.possibleTypes shouldBe empty } "resolve a 'createTable' call indirectly from 'foo.d' field access correctly" in { val List(d) = cpg.file.name(".*Bar.*").ast.isCall.name("createTable").l d.methodFullName shouldBe "flask_sqlalchemy:SQLAlchemy:createTable" - d.dynamicTypeHintFullName shouldBe Seq() + d.possibleTypes shouldBe empty d.callee(NoResolve).isExternal.headOption shouldBe Some(true) } @@ -197,7 +197,7 @@ class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { .name("deleteTable") .l d.methodFullName shouldBe "flask_sqlalchemy:SQLAlchemy:deleteTable" - d.dynamicTypeHintFullName shouldBe empty + d.possibleTypes shouldBe empty d.callee(NoResolve).isExternal.headOption shouldBe Some(true) } @@ -229,7 +229,7 @@ class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { ) "resolve correct imports via tag nodes" in { - val List(x: ResolvedMethod) = cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + val List(x: ResolvedMethod) = cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked x.fullName shouldBe "util.js::program:getIncrementalInteger" } @@ -258,7 +258,7 @@ class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { "resolve correct imports via tag nodes" in { val List(x: UnknownMethod, y: UnknownTypeDecl) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked x.fullName shouldBe "googleapis" y.fullName shouldBe "googleapis" } @@ -279,10 +279,15 @@ class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { |""".stripMargin) "resolve correct imports via tag nodes" in { - val List(x: UnknownMethod, y: UnknownTypeDecl, z: UnknownMethod) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + val List(w: UnknownMethod, x: UnknownTypeDecl, y: UnknownMethod, z: UnknownTypeDecl) = + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked + w.alias shouldBe "google" + w.fullName shouldBe "googleapis" + x.alias shouldBe "google" x.fullName shouldBe "googleapis" + y.alias shouldBe "_tmp_0" y.fullName shouldBe "googleapis" + z.alias shouldBe "_tmp_0" z.fullName shouldBe "googleapis" } @@ -381,7 +386,7 @@ class TypeRecoveryPassTests extends DataFlowCodeToCpgSuite { "resolve correct imports via tag nodes" in { val List(a: ResolvedTypeDecl, b: ResolvedMethod, c: ResolvedMethod, d: UnknownMethod, e: UnknownTypeDecl) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked a.fullName shouldBe "foo.js::program" b.fullName shouldBe "foo.js::program:literalFunction" c.fullName shouldBe "foo.js::program:get" diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Kotlin2Cpg.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Kotlin2Cpg.scala index e9daa305d3cb..12314464dbe1 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Kotlin2Cpg.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/Kotlin2Cpg.scala @@ -1,34 +1,27 @@ package io.joern.kotlin2cpg import better.files.File - -import java.nio.file.{Files, Paths} -import org.jetbrains.kotlin.psi.KtFile -import org.slf4j.LoggerFactory - -import scala.util.Try -import scala.jdk.CollectionConverters.{CollectionHasAsScala, EnumerationHasAsScala} -import io.joern.kotlin2cpg.files.SourceFilesPicker -import io.joern.kotlin2cpg.passes.{ - AstCreationPass, - ConfigPass, - DependenciesFromMavenCoordinatesPass, - KotlinTypeHintCallLinker, - KotlinTypeRecoveryPass -} import io.joern.kotlin2cpg.compiler.{CompilerAPI, ErrorLoggingMessageCollector} +import io.joern.kotlin2cpg.files.SourceFilesPicker +import io.joern.kotlin2cpg.interop.JavasrcInterop +import io.joern.kotlin2cpg.jar4import.UsesService +import io.joern.kotlin2cpg.passes.* import io.joern.kotlin2cpg.types.{ContentSourcesPicker, DefaultTypeInfoProvider} import io.joern.kotlin2cpg.utils.PathUtils import io.joern.x2cpg.X2Cpg.withNewEmptyCpg -import io.joern.x2cpg.{SourceFiles, X2CpgFrontend} -import io.joern.x2cpg.passes.frontend.{MetaDataPass, TypeNodePass, XTypeRecoveryConfig} +import io.joern.x2cpg.passes.frontend.{MetaDataPass, TypeNodePass, TypeRecoveryConfig} import io.joern.x2cpg.utils.dependency.{DependencyResolver, DependencyResolverParams, GradleConfigKeys} -import io.joern.kotlin2cpg.interop.JavasrcInterop -import io.joern.kotlin2cpg.jar4import.UsesService +import io.joern.x2cpg.{SourceFiles, X2CpgFrontend} import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.Languages import io.shiftleft.semanticcpg.language.* import io.shiftleft.utils.IOUtils +import org.jetbrains.kotlin.psi.KtFile +import org.slf4j.LoggerFactory + +import java.nio.file.{Files, Paths} +import scala.jdk.CollectionConverters.{CollectionHasAsScala, EnumerationHasAsScala} +import scala.util.Try object Kotlin2Cpg { val language = "KOTLIN" diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeRecoveryPass.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeRecoveryPass.scala index 8ff5be58da2e..376b15fa490d 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeRecoveryPass.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/passes/KotlinTypeRecoveryPass.scala @@ -3,66 +3,79 @@ package io.joern.kotlin2cpg.passes import io.joern.kotlin2cpg.Constants import io.joern.x2cpg.Defines import io.joern.x2cpg.passes.frontend.* +import io.joern.x2cpg.passes.frontend.ImportsPass.ResolvedImport import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.PropertyNames import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.semanticcpg.language.* import overflowdb.BatchedUpdate.DiffGraphBuilder -class KotlinTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) - extends XTypeRecoveryPass[File](cpg, config) { - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = - new KotlinTypeRecovery(cpg, state) +import java.util.concurrent.ExecutorService +import overflowdb.traversal.ImplicitsTmp.toTraversalSugarExt +class KotlinTypeRecoveryPass(cpg: Cpg, config: TypeRecoveryConfig = TypeRecoveryConfig()) + extends XTypeRecoveryPass(cpg, config) { + override protected def generateRecoveryPass(state: TypeRecoveryState, executor: ExecutorService): XTypeRecovery = + new KotlinTypeRecovery(cpg, state, executor) } -private class KotlinTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[File](cpg, state) { - - override def compilationUnit: Iterator[File] = cpg.file.iterator - - override def generateRecoveryForCompilationUnitTask( - unit: File, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[File] = { - val newConfig = state.config.copy(enabledDummyTypes = state.isFinalIteration && state.config.enabledDummyTypes) - new RecoverForKotlinFile(cpg, unit, builder, state.copy(config = newConfig)) - } -} - -private class RecoverForKotlinFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState) - extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) { +private class KotlinTypeRecovery(cpg: Cpg, state: TypeRecoveryState, executor: ExecutorService) + extends XTypeRecovery(cpg, state, executor) { private def kotlinNodeToLocalKey(n: AstNode): Option[LocalKey] = n match { case i: Identifier if i.name == "this" && i.code == "super" => Option(LocalVar("super")) case _ => SBKey.fromNodeToLocalKey(n) } - override protected val symbolTable = new SymbolTable[LocalKey](kotlinNodeToLocalKey) - - override protected def importNodes: Iterator[Import] = cu.ast.isImport - override protected def visitImport(i: Import): Unit = { - - val alias = i.importedAs.getOrElse("") - val fullName = i.importedEntity.getOrElse("") - if (alias != Constants.wildcardImportName) { - symbolTable.append(CallAlias(alias), fullName) - symbolTable.append(LocalVar(alias), fullName) + override protected val initialSymbolTable = new SymbolTable[LocalKey](kotlinNodeToLocalKey) + + override protected def recoverTypesForProcedure( + cpg: Cpg, + procedure: Method, + initialSymbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState + ): RecoverTypesForProcedure = new RecoverForKotlinProcedure(cpg, procedure, initialSymbolTable, builder, state) + + override protected def importNodes(cu: File): List[ResolvedImport] = + cu.namespaceBlock.flatMap(_.astOut).collectAll[Import].flatMap(visitImport).l + + // Kotlin has a much simpler import structure that doesn't need resolution + override protected def visitImport(i: Import): Iterator[ImportsPass.ResolvedImport] = { + for { + alias <- i.importedAs + fullName <- i.importedEntity + } { + if (alias != Constants.wildcardImportName) { + initialSymbolTable.append(CallAlias(alias, Option("this")), fullName) + initialSymbolTable.append(LocalVar(alias), fullName) + } } + Iterator.empty } - override protected def isConstructor(c: Call): Boolean = isConstructor(c.name) - - override protected def isConstructor(name: String): Boolean = !name.isBlank && name.charAt(0).isUpper - override protected def postVisitImports(): Unit = { - symbolTable.view.foreach { case (k, ts) => + initialSymbolTable.view.foreach { case (k, ts) => val tss = ts.filterNot(_.startsWith(Defines.UnresolvedNamespace)) if (tss.isEmpty) - symbolTable.remove(k) + initialSymbolTable.remove(k) else - symbolTable.put(k, tss) + initialSymbolTable.put(k, tss) } } +} + +private class RecoverForKotlinProcedure( + cpg: Cpg, + procedure: Method, + symbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState +) extends RecoverTypesForProcedure(cpg, procedure, symbolTable, builder, state) { + + override protected def isConstructor(c: Call): Boolean = isConstructor(c.name) + + override protected def isConstructor(name: String): Boolean = !name.isBlank && name.charAt(0).isUpper + // There seems to be issues with inferring these, often due to situations where super and this are confused on name // and code properties. override protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = if (i.name != "this") { @@ -76,7 +89,7 @@ private class RecoverForKotlinFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder case t if t.endsWith(c.signature) => t case t => s"$t:${c.signature}" } - builder.setNodeProperty(c, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, signedTypes) + builder.setNodeProperty(c, PropertyNames.POSSIBLE_TYPES, signedTypes) } } diff --git a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeNodeTests.scala b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeNodeTests.scala index 13e543aea12d..2ecb27453910 100644 --- a/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeNodeTests.scala +++ b/joern-cli/frontends/php2cpg/src/test/scala/io/joern/php2cpg/querying/TypeNodeTests.scala @@ -27,7 +27,6 @@ class TypeNodeTests extends PhpCode2CpgFixture { val cpg = code(""" cpg.typeDecl.fullNameExact(x) ++ cpg.method.fullNameExact(x)).collect { case x: Method => ResolvedMethod(x.fullName, alias) - case x: TypeDecl => ResolvedTypeDecl(x.fullName) + case x: TypeDecl => ResolvedTypeDecl(x.fullName, alias) } if (resolvedEntities.isEmpty) { - traversal.filterNot(_.contains("__init__.py")).map(x => UnknownImport(x)) + traversal.filterNot(_.contains("__init__.py")).map(x => UnknownImport(x, alias)) } else { resolvedEntities } @@ -117,9 +117,10 @@ class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { ResolvedMethod(t.method.nameExact(m.name).fullName.head, alias) case (t, m) if t.astSiblings.isMethod.fullNameExact(t.fullName).ast.isTypeDecl.nameExact(m.name).nonEmpty => ResolvedTypeDecl( - t.astSiblings.isMethod.fullNameExact(t.fullName).ast.isTypeDecl.nameExact(m.name).fullName.head + t.astSiblings.isMethod.fullNameExact(t.fullName).ast.isTypeDecl.nameExact(m.name).fullName.head, + alias ) - case (t, m) => ResolvedMember(t.fullName, m.name) + case (t, m) => ResolvedMember(t.fullName, m.name, alias) } case _ => // Case 4: Import from module using alias, e.g. import bar from foo as faz @@ -136,13 +137,13 @@ class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg) { }).flatMap { // If we import the constructor, we also import the type case x: ResolvedMethod if isMaybeConstructor => - Seq(ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias), ResolvedTypeDecl(x.fullName)) + Seq(ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias), ResolvedTypeDecl(x.fullName, alias)) // If we import the type, we also import the constructor case x: ResolvedTypeDecl if isMaybeConstructor => Seq(x, ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias)) // If we can determine the import is a constructor, then it is likely not a member case x: UnknownImport if isMaybeConstructor => - Seq(UnknownMethod(Seq(x.path, "__init__").mkString(pathSep), alias), UnknownTypeDecl(x.path)) + Seq(UnknownMethod(Seq(x.path, "__init__").mkString(pathSep), alias), UnknownTypeDecl(x.path, alias)) case x => Seq(x) }.toSet } diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeHintCallLinker.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeHintCallLinker.scala index 61593844ba78..499d3e8030b9 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeHintCallLinker.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeHintCallLinker.scala @@ -23,7 +23,9 @@ class PythonTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg) { } else if (methodNames.sizeIs > 1) { val nonDummyMethodNames = methodNames.filterNot(x => isDummyType(x) || x.startsWith(PythonAstVisitor.builtinPrefix + "None")) - super.setCallees(call, nonDummyMethodNames, builder) + + if (nonDummyMethodNames.nonEmpty) super.setCallees(call, nonDummyMethodNames, builder) + else super.setCallees(call, methodNames, builder) } } diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala index a84a423baafa..0f1860559c62 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala @@ -1,39 +1,42 @@ package io.joern.pysrc2cpg -import io.joern.x2cpg.passes.frontend._ +import io.joern.x2cpg.passes.frontend.* +import io.joern.x2cpg.passes.frontend.ImportsPass.* import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.nodes._ +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.operatorextension.OpNodes import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess import overflowdb.BatchedUpdate.DiffGraphBuilder -class PythonTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) - extends XTypeRecoveryPass[File](cpg, config) { +import java.util.concurrent.ExecutorService - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = - new PythonTypeRecovery(cpg, state) -} +class PythonTypeRecoveryPass(cpg: Cpg, config: TypeRecoveryConfig = TypeRecoveryConfig()) + extends XTypeRecoveryPass(cpg, config) { -private class PythonTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[File](cpg, state) { + override protected def generateRecoveryPass(state: TypeRecoveryState, executor: ExecutorService): XTypeRecovery = + new PythonTypeRecovery(cpg, state, executor) +} - override def compilationUnit: Iterator[File] = cpg.file.iterator +private class PythonTypeRecovery(cpg: Cpg, state: TypeRecoveryState, executor: ExecutorService) + extends XTypeRecovery(cpg, state, executor) { - override def generateRecoveryForCompilationUnitTask( - unit: File, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[File] = { - val newConfig = state.config.copy(enabledDummyTypes = state.isFinalIteration && state.config.enabledDummyTypes) - new RecoverForPythonFile(cpg, unit, builder, state.copy(config = newConfig)) - } + import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt -} + override val initialSymbolTable: SymbolTable[LocalKey] = SymbolTable[LocalKey](fromNodeToLocalPythonKey) -/** Performs type recovery from the root of a compilation unit level - */ -private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState) - extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) { + override def loadImports(i: ResolvedImport, symbolTable: SymbolTable[LocalKey]): Unit = i match { + case ResolvedMember(basePath, memberName, alias, _) => + val memberTypes = cpg.typeDecl + .fullNameExact(basePath) + .member + .nameExact(memberName) + .flatMap(_.getKnownTypes) + .toSet + symbolTable.put(LocalVar(alias), memberTypes) + case _ => super.loadImports(i, symbolTable) + } /** Replaces the `this` prefix with the Pythonic `self` prefix for instance methods of functions local to this * compilation unit. @@ -44,35 +47,28 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder case _ => SBKey.fromNodeToLocalKey(node) } - override val symbolTable: SymbolTable[LocalKey] = new SymbolTable[LocalKey](fromNodeToLocalPythonKey) - - override def visitImport(i: Import): Unit = { - if (i.importedAs.isDefined && i.importedEntity.isDefined) { - import io.joern.x2cpg.passes.frontend.ImportsPass._ - - val entityName = i.importedAs.get - i.call.tag.flatMap(ResolvedImport.tagToResolvedImport).foreach { - case ResolvedMethod(fullName, alias, receiver, _) => symbolTable.put(CallAlias(alias, receiver), fullName) - case ResolvedTypeDecl(fullName, _) => symbolTable.put(LocalVar(entityName), fullName) - case ResolvedMember(basePath, memberName, _) => - val memberTypes = cpg.typeDecl - .fullNameExact(basePath) - .member - .nameExact(memberName) - .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) - .filterNot(_ == "ANY") - .toSet - symbolTable.put(LocalVar(entityName), memberTypes) - case UnknownMethod(fullName, alias, receiver, _) => - symbolTable.put(CallAlias(alias, receiver), fullName) - case UnknownTypeDecl(fullName, _) => - symbolTable.put(LocalVar(entityName), fullName) - case UnknownImport(path, _) => - symbolTable.put(CallAlias(entityName), path) - symbolTable.put(LocalVar(entityName), path) - } - } - } + override protected def recoverTypesForProcedure( + cpg: Cpg, + procedure: Method, + initialSymbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState + ): RecoverTypesForProcedure = + RecoverForPythonProcedure(cpg, procedure, initialSymbolTable, builder, state) + +} + +/** Performs type recovery from the root of a compilation unit level + */ +private class RecoverForPythonProcedure( + cpg: Cpg, + procedure: Method, + symbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState +) extends RecoverTypesForProcedure(cpg, procedure, symbolTable, builder, state) { + + import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt override def visitAssignments(a: OpNodes.Assignment): Set[String] = { a.argumentOut.l match { @@ -111,6 +107,16 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder } } + override def visitStatementsInBlock(b: Block, assignmentTarget: Option[Identifier]): Set[String] = { + if (b.inAssignment.nonEmpty && b.expressionDown.assignment.argument(1).fieldAccess.code(".*").nonEmpty) { + super.visitStatementsInBlock(b, assignmentTarget) + // Shortcut the actual value of the module access + visitAssignmentArguments(List(b.inAssignment.target.head, b.expressionDown.assignment.head.source)) + } else { + super.visitStatementsInBlock(b, assignmentTarget) + } + } + override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = { val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"${pathSep}__init__")) associateTypes(i, constructorPaths) @@ -129,8 +135,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder fa.astChildren.l match { case List(base: Identifier, fi: FieldIdentifier) if base.name.equals("self") && fieldParents.nonEmpty => val referencedFields = cpg.typeDecl.fullNameExact(fieldParents.toSeq: _*).member.nameExact(fi.canonicalName) - val globalTypes = - referencedFields.flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName).filterNot(_ == Constants.ANY).toSet + val globalTypes = referencedFields.flatMap(_.getKnownTypes).toSet associateTypes(i, globalTypes) case _ => super.visitIdentifierAssignedToFieldLoad(i, fa) } @@ -182,8 +187,9 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder } } - override protected def postSetTypeInformation(): Unit = - cu.typeDecl + override protected def postSetTypeInformation(): Unit = { + super.postSetTypeInformation() + procedure.typeDecl .map(t => t -> t.inheritsFromTypeFullName.partition(itf => symbolTable.contains(LocalVar(itf)))) .foreach { case (t, (identifierTypes, otherTypes)) => val existingTypes = (identifierTypes ++ otherTypes).distinct @@ -193,9 +199,10 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder builder.setNodeProperty(t, PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, resolvedTypes) } } + } override def prepopulateSymbolTable(): Unit = { - cu.ast.isMethodRef.where(_.astSiblings.isIdentifier.nameExact("classmethod")).referencedMethod.foreach { + procedure.ast.isMethodRef.where(_.astSiblings.isIdentifier.nameExact("classmethod")).referencedMethod.foreach { classMethod => classMethod.parameter .nameExact("cls") @@ -203,7 +210,7 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder val clsPath = classMethod.typeDecl.fullName.toSet symbolTable.put(LocalVar(cls.name), clsPath) if (cls.typeFullName == "ANY") - builder.setNodeProperty(cls, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, clsPath.toSeq) + builder.setNodeProperty(cls, PropertyNames.POSSIBLE_TYPES, clsPath.toSeq) } } super.prepopulateSymbolTable() @@ -216,4 +223,27 @@ private class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder .headOption .getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, rec)) + override protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = + super.storeCallTypeInfo(c, types.filterNot(_.startsWith("__builtin.None"))) + + override protected def persistType(x: StoredNode, types: Set[String]): Unit = x match { + case _: Call => super.persistType(x, types.filterNot(_.startsWith("__builtin.None"))) + case _ => super.persistType(x, types.filterNot(_.matches("__builtin\\.None.+"))) + } + + override protected def setTypes(n: StoredNode, types: Seq[String]): Unit = n match { + case _: Call => super.setTypes(n, types.filterNot(_.startsWith("__builtin.None"))) + case _ => super.setTypes(n, types.filterNot(_.matches("__builtin\\.None.+"))) + } + + override protected def handlePotentialFunctionPointer( + funcPtr: Expression, + baseTypes: Set[String], + funcName: String, + baseName: Option[String] + ): Unit = { + if (funcName != "") + super.handlePotentialFunctionPointer(funcPtr, baseTypes, funcName, baseName) + } + } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/PySrc2CpgFixture.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/PySrc2CpgFixture.scala index 16ac4b3a4d86..1c9ae9d332a3 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/PySrc2CpgFixture.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/PySrc2CpgFixture.scala @@ -3,9 +3,10 @@ package io.joern.pysrc2cpg import io.joern.dataflowengineoss.layers.dataflows.{OssDataFlow, OssDataFlowOptions} import io.joern.dataflowengineoss.queryengine.EngineContext import io.joern.dataflowengineoss.semanticsloader.FlowSemantic -import io.joern.x2cpg.X2Cpg +import io.joern.x2cpg.{X2Cpg, X2CpgConfig} import io.joern.x2cpg.passes.base.AstLinkerPass import io.joern.x2cpg.passes.callgraph.NaiveCallLinker +import io.joern.x2cpg.passes.frontend.{TypeRecoveryConfig, TypeRecoveryParserConfig} import io.joern.x2cpg.testfixtures.{Code2CpgFixture, LanguageFrontend, TestCpg} import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.semanticcpg.language.{ICallResolver, NoResolve} @@ -39,7 +40,14 @@ class PySrcTestCpg extends TestCpg with PythonFrontend { new ImportResolverPass(this).createAndApply() new PythonInheritanceNamePass(this).createAndApply() new DynamicTypeHintFullNamePass(this).createAndApply() - new PythonTypeRecoveryPass(this).createAndApply() + getConfig() match + case Some(config: X2CpgConfig[_] with TypeRecoveryParserConfig[_]) => + new PythonTypeRecoveryPass( + this, + TypeRecoveryConfig(config.typePropagationIterations, !config.disableDummyTypes) + ).createAndApply() + case _ => new PythonTypeRecoveryPass(this).createAndApply() + new PythonTypeHintCallLinker(this).createAndApply() new NaiveCallLinker(this).createAndApply() diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/PyTypeRecoveryPassTests.scala similarity index 94% rename from joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala rename to joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/PyTypeRecoveryPassTests.scala index 4416eb4d4572..3ab7acf85ba5 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/PyTypeRecoveryPassTests.scala @@ -1,13 +1,13 @@ package io.joern.pysrc2cpg.passes -import io.joern.pysrc2cpg.PySrc2CpgFixture +import io.joern.pysrc2cpg.{Py2CpgOnFileSystemConfig, PySrc2CpgFixture} import io.joern.x2cpg.passes.frontend.ImportsPass.* import io.joern.x2cpg.passes.frontend.{ImportsPass, XTypeHintCallLinker} import io.shiftleft.semanticcpg.language.* import java.io.File import scala.collection.immutable.Seq -class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { +class PyTypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "literals declared from built-in types" should { lazy val cpg = code(""" @@ -25,15 +25,15 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve 'x' identifier types despite shadowing" in { val List(xOuterScope, xInnerScope) = cpg.identifier("x").take(2).l - xOuterScope.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str") - xInnerScope.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str") + xOuterScope.possibleTypes.sorted shouldBe Seq("__builtin.int", "__builtin.str") + xInnerScope.possibleTypes.sorted shouldBe Seq("__builtin.int", "__builtin.str") } "resolve 'y' and 'z' identifier collection types" in { val List(zDict, zList, zTuple) = cpg.identifier("z").take(3).l - zDict.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple") - zList.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple") - zTuple.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple") + zDict.possibleTypes.sorted shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple") + zList.possibleTypes.sorted shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple") + zTuple.possibleTypes.sorted shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple") } "resolve 'z' identifier calls conservatively" in { @@ -41,7 +41,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { zAppend.methodFullName shouldBe "" // Since we don't have method nodes with this full name, this should belong to the call linker namespace zAppend.callee.astParentFullName.headOption shouldBe Some(XTypeHintCallLinker.namespace) - zAppend.dynamicTypeHintFullName shouldBe Seq( + zAppend.possibleTypes.sorted shouldBe Seq( "__builtin.dict.append", "__builtin.list.append", "__builtin.tuple.append" @@ -70,7 +70,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { webClientT: UnknownTypeDecl, sendGridM: UnknownMethod, sendGridT: UnknownTypeDecl - ) = cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + ) = cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked webClientM.fullName shouldBe "slack_sdk.py:.WebClient.__init__" webClientT.fullName shouldBe "slack_sdk.py:.WebClient" sendGridM.fullName shouldBe "sendgrid.py:.SendGridAPIClient.__init__" @@ -183,7 +183,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "conservatively present either option when an imported function uses the same name as a builtin" in { val Some(absCall) = cpg.call("abs").headOption: @unchecked - absCall.dynamicTypeHintFullName shouldBe Seq("foo.py:.abs", "__builtin.abs") + absCall.possibleTypes shouldBe Seq("foo.py:.abs", "__builtin.abs") } } @@ -216,11 +216,11 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve correct imports via tag nodes" in { val List(foo1: UnknownMethod, foo2: UnknownTypeDecl) = - cpg.file(".*foo.py").ast.isCall.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.file(".*foo.py").ast.isCall.where(_.referencedImports).toResolvedImport.toList: @unchecked foo1.fullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.__init__" foo2.fullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy" val List(bar1: ResolvedTypeDecl, bar2: ResolvedMethod) = - cpg.file(".*bar.py").ast.isCall.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.file(".*bar.py").ast.isCall.where(_.referencedImports).toResolvedImport.toList: @unchecked bar1.fullName shouldBe "foo.py:" bar2.fullName shouldBe "foo.py:" } @@ -240,9 +240,9 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { .name("z") .l z1.typeFullName shouldBe "ANY" - z1.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str") + z1.possibleTypes.sorted shouldBe Seq("__builtin.int", "__builtin.str") z2.typeFullName shouldBe "ANY" - z2.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str") + z2.possibleTypes.sorted shouldBe Seq("__builtin.int", "__builtin.str") } "resolve 'foo.d' field access object types correctly" in { @@ -253,7 +253,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { .name("d") .headOption: @unchecked d.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy" - d.dynamicTypeHintFullName shouldBe Seq() + d.possibleTypes shouldBe empty } "resolve a 'createTable' call indirectly from 'foo.d' field access correctly" in { @@ -264,7 +264,6 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { .name("createTable") .l d.methodFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.createTable" - d.dynamicTypeHintFullName shouldBe Seq() d.callee(NoResolve).isExternal.headOption shouldBe Some(true) } @@ -277,7 +276,6 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { .l d.methodFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.deleteTable" - d.dynamicTypeHintFullName shouldBe Seq() d.callee(NoResolve).isExternal.headOption shouldBe Some(true) } @@ -309,7 +307,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve correct imports via tag nodes" in { val List(a: ResolvedTypeDecl, b: ResolvedMethod, c: UnknownImport, d: ResolvedMember) = - cpg.file(".*UserController.py").ast.isCall.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.file(".*UserController.py").ast.isCall.where(_.referencedImports).toResolvedImport.toList: @unchecked a.fullName shouldBe "app.py:" b.fullName shouldBe "app.py:" c.path shouldBe "flask.py:.jsonify" @@ -317,7 +315,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { d.memberName shouldBe "db" val List(sqlAlchemyM: UnknownMethod, sqlAlchemyT: UnknownTypeDecl) = - cpg.file(".*app.py").ast.isCall.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.file(".*app.py").ast.isCall.where(_.referencedImports).toResolvedImport.toList: @unchecked sqlAlchemyM.fullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.__init__" sqlAlchemyT.fullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy" } @@ -355,7 +353,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { |""".stripMargin).cpg "resolve correct imports via tag nodes" in { - val List(logging: UnknownImport) = cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + val List(logging: UnknownImport) = cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked logging.path shouldBe "logging.py:" } @@ -379,7 +377,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve correct imports via tag nodes" in { val List(error: UnknownImport, request: UnknownImport) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked error.path shouldBe "urllib.py:.error" request.path shouldBe "urllib.py:.request" } @@ -435,11 +433,11 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { |# dummy file to trigger isExternal = false on methods that are imported from here |""".stripMargin, "pymongo.py" - ).cpg + ) "resolve correct imports via tag nodes" in { val List(a: ResolvedTypeDecl, b: ResolvedMethod, c: UnknownMethod, d: UnknownTypeDecl, e: UnknownImport) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked a.fullName shouldBe "MongoConnection.py:.MongoConnection" b.fullName shouldBe "MongoConnection.py:.MongoConnection.__init__" @@ -450,16 +448,9 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "recover a potential type for `self.collection` using the assignment at `get_collection` as a type hint" in { val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption: @unchecked - selfFindFound.dynamicTypeHintFullName shouldBe Seq( - "__builtin.None.find_one", - "pymongo.py:.MongoClient.__init__...find_one" - ) + selfFindFound.methodFullName shouldBe "pymongo.py:.MongoClient.__init__....find_one" } - "correctly determine that, despite being unable to resolve the correct method full name, that it is an internal method" in { - val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption: @unchecked - selfFindFound.callee.isExternal.toSeq shouldBe Seq(true, true) - } } "a recursive field access based call type" should { @@ -568,7 +559,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { sessionM: ResolvedMethod, sqlSessionM: UnknownMethod, sqlSessionT: UnknownTypeDecl - ) = cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + ) = cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked sessionT.fullName shouldBe Seq("data", "db_session.py:").mkString(File.separator) sessionM.fullName shouldBe Seq("data", "db_session.py:").mkString(File.separator) sqlSessionM.fullName shouldBe Seq("sqlalchemy", "orm.py:.Session.__init__").mkString(File.separator) @@ -655,7 +646,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve correct imports via tag nodes" in { val List(sqlSessionM: UnknownMethod, sqlSessionT: UnknownTypeDecl, db: ResolvedMember) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked sqlSessionM.fullName shouldBe Seq("flask_sqlalchemy.py:.SQLAlchemy.__init__").mkString(File.separator) sqlSessionT.fullName shouldBe Seq("flask_sqlalchemy.py:.SQLAlchemy").mkString(File.separator) db.basePath shouldBe Seq("api", "__init__.py:").mkString(File.separator) @@ -936,7 +927,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve correct imports via tag nodes" in { val List(djangoModels: UnknownImport, profileT: ResolvedTypeDecl, profileM: ResolvedMethod) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked djangoModels.path shouldBe Seq("django", "db.py:.models").mkString(File.separator) profileT.fullName shouldBe "models.py:.Profile" profileM.fullName shouldBe "models.py:.Profile.__init__" @@ -974,11 +965,11 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { | c.getBotoClient().getS3Object() |""".stripMargin, "impl.py" - ) + ).withConfig(Py2CpgOnFileSystemConfig().withTypePropagationIterations(3)) "resolve correct imports via tag nodes" in { val List(connectorT: ResolvedTypeDecl, connectorM: ResolvedMethod) = - cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked connectorT.fullName shouldBe Seq("lib", "connector.py:.Connector").mkString(File.separator) connectorM.fullName shouldBe Seq("lib", "connector.py:.Connector.__init__").mkString(File.separator) } @@ -1050,9 +1041,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "be able to handle a simple call off an alias" in { val Some(redisGet) = cpg.call.nameExact("publish_json").headOption: @unchecked - redisGet.methodFullName shouldBe Seq("db", "redis.py:.RedisDB.get_redis.publish_json").mkString( - File.separator - ) + redisGet.methodFullName shouldBe "aioredis.py:.Redis.publish_json" } } @@ -1104,6 +1093,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { |""".stripMargin, Seq("oauth2", "__init__.py").mkString(File.separator) ) + .withConfig(Py2CpgOnFileSystemConfig().withTypePropagationIterations(4)) "instantiate the return value correctly under `from_string`" in { val Some(token) = cpg.method("from_string").ast.isIdentifier.nameExact("token").headOption: @unchecked diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index c3f9852e185f..2f5bec99d9e1 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -567,6 +567,8 @@ class AstCreator( val parenAst = astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses()) val callNode = methodIdAst.head.nodes.filter(_.isInstanceOf[NewCall]).head.asInstanceOf[NewCall] callNode.name(resolveAlias(callNode.name)) + // TODO: We need a receiver of some kind + val receiverAst = Option(Ast(createThisIdentifier(ctx))) if (ctx.block() != null) { val isYieldMethod = if (callNode.name.endsWith(YIELD_SUFFIX)) { @@ -578,13 +580,13 @@ class AstCreator( if (isYieldMethod) { val methAst = astForBlock(ctx.block(), Some(callNode.name)) blockMethods.addOne(methAst) - Seq(callAst(callNode, parenAst)) + Seq(callAst(callNode, parenAst, receiverAst)) } else { val blockAst = Seq(astForBlock(ctx.block())) - Seq(callAst(callNode, parenAst ++ blockAst)) + Seq(callAst(callNode, parenAst ++ blockAst, receiverAst)) } } else - Seq(callAst(callNode, parenAst)) + Seq(callAst(callNode, parenAst, receiverAst)) } def astForCallNode(ctx: ParserRuleContext, code: String, isYieldBlock: Boolean = false): Ast = { @@ -596,8 +598,9 @@ class AstCreator( if (isBuiltin(calleeName)) builtInCallNames.add(calleeName) calleeName } - - callAst(callNode(ctx, code, name, DynamicCallUnknownFullName, DispatchTypes.STATIC_DISPATCH)) + // TODO: We need some kind of receiver, so I'm adding this + val thisNode = Option(Ast(createThisIdentifier(ctx))) + callAst(callNode(ctx, code, name, DynamicCallUnknownFullName, DispatchTypes.STATIC_DISPATCH), base = thisNode) } private def astForMethodOnlyIdentifier(ctx: MethodOnlyIdentifierContext): Seq[Ast] = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index 432b4944f4af..e7de69f0c50a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -249,7 +249,8 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V val methodIdentifierAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), text(ctx)) methodIdentifierAsts.headOption.foreach(methodNameAsIdentifierStack.push) val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - + // TODO: We need a receiver of some kind + val receiverAst = Option(Ast(createThisIdentifier(ctx))) /* get args without the method def in it */ val argAstsWithoutMethods = argsAsts.filterNot(_.root.exists(_.isInstanceOf[NewMethod])) @@ -282,9 +283,9 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V resolveRelativePath(filename, argsAsts, callNode) } else if (prefixMethods.contains(callNode.name)) { /* we remove the method definition AST from argument and add its corresponding identifier form */ - Seq(callAst(callNode, argAstsWithoutMethods ++ methodToIdentifierAsts)) + Seq(callAst(callNode, argAstsWithoutMethods ++ methodToIdentifierAsts, receiverAst)) } else { - Seq(callAst(callNode, argsAsts)) + Seq(callAst(callNode, argAstsWithoutMethods, receiverAst)) } } else { argsAsts diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImportResolverPass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImportResolverPass.scala index 3f2754e27c31..8ccbcf85c297 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImportResolverPass.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/ImportResolverPass.scala @@ -23,10 +23,10 @@ class ImportResolverPass(cpg: Cpg, packageTableInfo: PackageTable) extends XImpo diffGraph: DiffGraphBuilder ): Unit = { - resolveEntities(importedEntity, importCall, fileName).foreach(x => resolvedImportToTag(x, importCall, diffGraph)) + resolveEntities(importedEntity, importedAs, fileName).foreach(x => resolvedImportToTag(x, importCall, diffGraph)) } - private def resolveEntities(expEntity: String, importCall: Call, fileName: String): Set[ResolvedImport] = { + private def resolveEntities(expEntity: String, alias: String, fileName: String): Set[ResolvedImport] = { // TODO /* Currently we are considering only case where exposed module are Classes, @@ -54,13 +54,13 @@ class ImportResolverPass(cpg: Cpg, packageTableInfo: PackageTable) extends XImpo .flatMap { typeDeclModel => Seq( ResolvedMethod(s"${typeDeclModel.fullName}.${XDefines.ConstructorMethodName}", "new"), - ResolvedTypeDecl(typeDeclModel.fullName) + ResolvedTypeDecl(typeDeclModel.fullName, alias) ) } .distinct val importNodesFromModule = packageTableInfo.getModule(expEntity).flatMap { moduleModel => - Seq(ResolvedTypeDecl(moduleModel.fullName)) + Seq(ResolvedTypeDecl(moduleModel.fullName, alias)) } (importNodesFromTypeDecl ++ importNodesFromModule).toSet } else { @@ -69,14 +69,17 @@ class ImportResolverPass(cpg: Cpg, packageTableInfo: PackageTable) extends XImpo .where(_.file.name(filePattern)) .fullName .flatMap(fullName => - Seq(ResolvedTypeDecl(fullName), ResolvedMethod(s"$fullName.${XDefines.ConstructorMethodName}", "new")) + Seq( + ResolvedTypeDecl(fullName, alias), + ResolvedMethod(s"$fullName.${XDefines.ConstructorMethodName}", "new") + ) ) .toSet val resolvedModules = cpg.namespaceBlock .whereNot(_.nameExact("")) .where(_.file.name(filePattern)) - .flatMap(module => Seq(ResolvedTypeDecl(module.fullName))) + .flatMap(module => Seq(ResolvedTypeDecl(module.fullName, alias))) .toSet // Expose methods which are directly present in a file, without any module, TypeDecl diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryPass.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryPass.scala index b75790fb11f6..9ee6ef180f78 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryPass.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryPass.scala @@ -1,33 +1,47 @@ package io.joern.rubysrc2cpg.passes +import io.joern.x2cpg.Defines as XDefines import io.joern.x2cpg.passes.frontend.* +import io.joern.x2cpg.passes.frontend.ImportsPass.ResolvedTypeDecl import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* import overflowdb.BatchedUpdate.DiffGraphBuilder -import io.joern.x2cpg.Defines as XDefines -class RubyTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) - extends XTypeRecoveryPass[File](cpg, config) { - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = - new RubyTypeRecovery(cpg, state) +import java.util.concurrent.ExecutorService + +class RubyTypeRecoveryPass(cpg: Cpg, config: TypeRecoveryConfig = TypeRecoveryConfig()) + extends XTypeRecoveryPass(cpg, config) { + override protected def generateRecoveryPass(state: TypeRecoveryState, executor: ExecutorService): XTypeRecovery = + new RubyTypeRecovery(cpg, state, executor) } -private class RubyTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[File](cpg, state) { +private class RubyTypeRecovery(cpg: Cpg, state: TypeRecoveryState, executor: ExecutorService) + extends XTypeRecovery(cpg, state, executor) { - override def compilationUnit: Iterator[File] = cpg.file.iterator + override protected def recoverTypesForProcedure( + cpg: Cpg, + procedure: Method, + initialSymbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState + ): RecoverTypesForProcedure = RecoverForRubyFile(cpg, procedure, initialSymbolTable, builder, state) - override def generateRecoveryForCompilationUnitTask( - unit: File, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[File] = { - val newConfig = state.config.copy(enabledDummyTypes = state.isFinalIteration && state.config.enabledDummyTypes) - new RecoverForRubyFile(cpg, unit, builder, state.copy(config = newConfig)) - } + override protected def loadImports(i: ImportsPass.ResolvedImport, symbolTable: SymbolTable[LocalKey]): Unit = + i match { + case ResolvedTypeDecl(fullName, alias, _) => + symbolTable.append(LocalVar(fullName.split("\\.").lastOption.getOrElse(alias)), fullName) + case _ => super.loadImports(i, symbolTable) + } } -private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, state: XTypeRecoveryState) - extends RecoverForXCompilationUnit[File](cpg, cu, builder, state) { +private class RecoverForRubyFile( + cpg: Cpg, + procedure: Method, + symbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState +) extends RecoverTypesForProcedure(cpg, procedure, symbolTable, builder, state) { /** A heuristic method to determine if a call is a constructor or not. */ @@ -40,17 +54,6 @@ private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, override protected def isConstructor(name: String): Boolean = !name.isBlank && (name == "new" || name == XDefines.ConstructorMethodName) - override def visitImport(i: Import): Unit = for { - resolvedImport <- i.call.tag - alias <- i.importedAs - } { - import io.joern.x2cpg.passes.frontend.ImportsPass.* - ResolvedImport.tagToResolvedImport(resolvedImport).foreach { - case ResolvedTypeDecl(fullName, _) => - symbolTable.append(LocalVar(fullName.split("\\.").lastOption.getOrElse(alias)), fullName) - case _ => super.visitImport(i) - } - } override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = { def isMatching(cName: String, code: String) = { @@ -67,9 +70,9 @@ private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, override def methodReturnValues(methodFullNames: Seq[String]): Set[String] = { // Check if we have a corresponding member to resolve type val memberTypes = methodFullNames.flatMap { fullName => - val memberName = fullName.split("\\.").lastOption + val memberName = fullName.split(pathSep).lastOption if (memberName.isDefined) { - val typeDeclFullName = fullName.stripSuffix(s".${memberName.get}") + val typeDeclFullName = fullName.stripSuffix(s"$pathSep${memberName.get}") cpg.typeDecl.fullName(typeDeclFullName).member.nameExact(memberName.get).typeFullName.l } else List.empty @@ -108,7 +111,7 @@ private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, } } - protected def setCallMethodFullNameFromBaseScopeResolution(c: Call): Set[String] = { + private def setCallMethodFullNameFromBaseScopeResolution(c: Call): Set[String] = { val recTypes = c.argument.headOption .map { case x: Call if x.name.equals(".scopeResolution") => @@ -119,11 +122,26 @@ private class RecoverForRubyFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, symbolTable.append(c, callTypes) } - override protected def visitIdentifierAssignedToTypeRef(i: Identifier, t: TypeRef, rec: Option[String]): Set[String] = + override protected def visitIdentifierAssignedToTypeRef( + i: Identifier, + t: TypeRef, + rec: Option[String] + ): Set[String] = { + val receiver = rec match + case Some(x) => Option(x) + case None => Option("this") t.typ.referencedTypeDecl .map(_.fullName.stripSuffix("")) - .map(td => symbolTable.append(CallAlias(i.name, rec), Set(td))) + .map(td => symbolTable.append(CallAlias(i.name, receiver), Set(td))) .headOption - .getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, rec)) + .getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, receiver)) + } + + override protected def visitIdentifierAssignedToMethodRef( + i: Identifier, + m: MethodRef, + rec: Option[String] = None + ): Set[String] = + super.visitIdentifierAssignedToMethodRef(i, m, Option("this")) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/DataFlowTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/DataFlowTests.scala index 9ccb565028dc..36467cf30883 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/DataFlowTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/DataFlowTests.scala @@ -35,19 +35,18 @@ class DataFlowTests extends RubyCode2CpgFixture(withPostProcessing = true, withD "Flow via call" should { val cpg = code(""" |def print(content) - |puts content + | puts content |end | |def main - |n = 1 - |print( n ) + | n = 1 + | print(n) |end |""".stripMargin) "be found" in { - implicit val resolver: ICallResolver = NoResolve - val src = cpg.identifier.name("n").where(_.inCall.name("print")).l - val sink = cpg.method.name("puts").callIn.argument(1).l + val src = cpg.identifier.name("n").where(_.inCall.name("print")).l + val sink = cpg.method.name("puts").callIn.argument(1).l sink.reachableByFlows(src).size shouldBe 1 } } @@ -1008,7 +1007,7 @@ class DataFlowTests extends RubyCode2CpgFixture(withPostProcessing = true, withD "find flows to the sink" in { val source = cpg.identifier.name("x").l val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 + sink.reachableByFlows(source).size shouldBe 8 } } @@ -1031,7 +1030,7 @@ class DataFlowTests extends RubyCode2CpgFixture(withPostProcessing = true, withD "find flows to the sink" in { val source = cpg.identifier.name("x").l val sink = cpg.call.name("puts").l - sink.reachableByFlows(source).size shouldBe 2 + sink.reachableByFlows(source).size shouldBe 8 } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryTests.scala index 2811a33d8d0e..9d4951a10f0f 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/RubyTypeRecoveryTests.scala @@ -3,6 +3,7 @@ package io.joern.rubysrc2cpg.passes import io.joern.rubysrc2cpg.Config import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.joern.rubysrc2cpg.utils.PackageTable +import io.joern.x2cpg.passes.frontend.ImportsPass.CallToResolvedImportExt import io.joern.x2cpg.passes.frontend.ImportsPass.{ ResolvedMethod, ResolvedTypeDecl, @@ -70,8 +71,8 @@ class RubyTypeRecoveryTests ) "resolve 'x' identifier types despite shadowing" in { val List(xOuterScope, xInnerScope) = cpg.identifier("x").take(2).l - xOuterScope.dynamicTypeHintFullName shouldBe Seq("__builtin.Integer", "__builtin.String") - xInnerScope.dynamicTypeHintFullName shouldBe Seq("__builtin.Integer", "__builtin.String") + xOuterScope.possibleTypes shouldBe Seq("__builtin.Integer", "__builtin.String") + xInnerScope.possibleTypes shouldBe Seq("__builtin.Integer", "__builtin.String") } "resolve module constant type" in { @@ -140,10 +141,10 @@ class RubyTypeRecoveryTests // TODO Waiting for Module modelling to be done "resolve correct imports via tag nodes" ignore { val List(foo: ResolvedTypeDecl) = - cpg.file(".*foo.rb").ast.isCall.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.file(".*foo.rb").ast.isCall.where(_.referencedImports).toResolvedImport.toList: @unchecked foo.fullName shouldBe "dbi::program.DBI" val List(bar: ResolvedTypeDecl) = - cpg.file(".*bar.rb").ast.isCall.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + cpg.file(".*bar.rb").ast.isCall.where(_.referencedImports).toResolvedImport.toList: @unchecked bar.fullName shouldBe "foo.rb::program.FooModule" } @@ -202,7 +203,7 @@ class RubyTypeRecoveryTests |""".stripMargin).cpg "resolve correct imports via tag nodes" in { - val List(logging: ResolvedMethod, _) = cpg.call.where(_.referencedImports).tag.toResolvedImport.toList: @unchecked + val List(logging: ResolvedMethod, _) = cpg.call.where(_.referencedImports).toResolvedImport.toList: @unchecked logging.fullName shouldBe s"logger::program.Logger.${XDefines.ConstructorMethodName}" } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/CallCpgTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/CallCpgTests.scala index cc93e004a91d..05a2dd200b37 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/CallCpgTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/CallCpgTests.scala @@ -28,15 +28,11 @@ class CallCpgTests extends RubyCode2CpgFixture(withPostProcessing = true) { } "test astChildren" in { - val callNode = cpg.call.name("foo").head - val children = callNode.astChildren - children.size shouldBe 2 - - val firstChild = children.head - val secondChild = children.last + val callNode = cpg.call.name("foo").head + val List(_, firstArg, secondArg) = callNode.astChildren.l: @unchecked - firstChild.code shouldBe "\"a\"" - secondChild.code shouldBe "b" + firstArg.code shouldBe "\"a\"" + secondArg.code shouldBe "b" } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala index 4361d21718c0..ab8bcd52187f 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala @@ -25,7 +25,7 @@ class DoBlockTest extends RubyCode2CpgFixture { name.name shouldBe "name" age.name shouldBe "age" - val List(value, unit) = nameMethod.local.l + val List(unit, value) = nameMethod.local.l value.name shouldBe "value" unit.name shouldBe "unit" } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala index 9e5ccfe1d5ef..c88bbc136216 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala @@ -181,8 +181,8 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { } "have correct structure for `self` identifier" in { - val cpg = code("puts self") - val List(self, _) = cpg.identifier.l + val cpg = code("puts self") + val List(self) = cpg.identifier("self").l self.typeFullName shouldBe Defines.Object self.code shouldBe "self" self.lineNumber shouldBe Some(1) @@ -190,8 +190,8 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { } "have correct structure for `__FILE__` identifier" in { - val cpg = code("puts __FILE__") - val List(file, _) = cpg.identifier.l + val cpg = code("puts __FILE__") + val List(file) = cpg.identifier("__FILE__").l file.typeFullName shouldBe "__builtin.String" file.code shouldBe "__FILE__" file.lineNumber shouldBe Some(1) @@ -199,8 +199,8 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { } "have correct structure for `__LINE__` identifier" in { - val cpg = code("puts __LINE__") - val List(line, _) = cpg.identifier.l + val cpg = code("puts __LINE__") + val List(line) = cpg.identifier("__LINE__").l line.typeFullName shouldBe "__builtin.Integer" line.code shouldBe "__LINE__" line.lineNumber shouldBe Some(1) @@ -208,8 +208,8 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { } "have correct structure for `__ENCODING__` identifier" in { - val cpg = code("puts __ENCODING__") - val List(encoding, _) = cpg.identifier.l + val cpg = code("puts __ENCODING__") + val List(encoding) = cpg.identifier("__ENCODING__").l encoding.typeFullName shouldBe Defines.Encoding encoding.code shouldBe "__ENCODING__" encoding.lineNumber shouldBe Some(1) @@ -1097,23 +1097,22 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { } // Change below test cases to focus on the argument of call `foo` - "have correct structure when a association is passed as an argument with parantheses" in { + "have correct structure when a association is passed as an argument with parentheses" in { val cpg = code("""foo(bar:)""".stripMargin) - cpg.argument.size shouldBe 2 - cpg.argument.l(0).code shouldBe "bar:" + val List(_, bar) = cpg.call("foo").argument.l + bar.code shouldBe "bar:" cpg.call.size shouldBe 2 val List(callNode, operatorNode) = cpg.call.l callNode.name shouldBe "foo" operatorNode.name shouldBe ".activeRecordAssociation" } - "have correct structure when a association is passed as an argument without parantheses" in { + "have correct structure when a association is passed as an argument without parentheses" in { val cpg = code("""foo bar:""".stripMargin) - cpg.argument.size shouldBe 2 - cpg.argument.l.head.code shouldBe "bar:" - + val List(_, bar) = cpg.call("foo").argument.l + bar.code shouldBe "bar:" cpg.call.size shouldBe 2 val List(callNode, operatorNode) = cpg.call.l callNode.name shouldBe "foo" @@ -1216,8 +1215,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { |val fileName = "AB\u0003\u0004\u0014\u0000\u0000\u0000\b\u0000\u0000\u0000!\u0000file" |""".stripMargin) - cpg.identifier.size shouldBe 1 - cpg.identifier.name.head shouldBe "fileName" + cpg.identifier("fileName").size shouldBe 1 cpg.literal.head.code .stripPrefix("\"") .stripSuffix("\"") @@ -1434,7 +1432,6 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { cpg.call.size shouldBe 2 cpg.call.name(".activeRecordAssociation").size shouldBe 1 - cpg.identifier.size shouldBe 2 cpg.identifier.name("a").size shouldBe 1 cpg.identifier.name("b").size shouldBe 1 } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala index 48e197ddcbd5..f211f810cfaa 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala @@ -4,6 +4,7 @@ import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.ControlStructureTypes import io.shiftleft.codepropertygraph.generated.nodes.{Block, ControlStructure} import io.shiftleft.semanticcpg.language.* + class ControlStructureTests extends RubyCode2CpgFixture { "CPG for code with doBlock iterating over a constant array" should { @@ -15,7 +16,7 @@ class ControlStructureTests extends RubyCode2CpgFixture { "recognise all identifier nodes" in { cpg.identifier.name("n").size shouldBe 1 - cpg.identifier.size shouldBe 2 // 1 identifier node is for `puts = typeDef(__builtin.puts)` + cpg.identifier.size shouldBe 3 // 1 identifier node is for `puts = typeDef(__builtin.puts)` and similarly for `each2` } "recognize all call nodes" in { @@ -56,7 +57,7 @@ class ControlStructureTests extends RubyCode2CpgFixture { "recognise all identifier nodes" in { cpg.identifier.name("n").size shouldBe 2 cpg.identifier.name("m").size shouldBe 1 - cpg.identifier.size shouldBe 5 + cpg.identifier.size shouldBe 6 cpg.method.name("fakeName").dotAst.l } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FunctionTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FunctionTests.scala index 68c0740b4e95..596f56ed1d2d 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FunctionTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/FunctionTests.scala @@ -36,7 +36,7 @@ class FunctionTests extends RubyCode2CpgFixture { cpg.identifier.name("age").size shouldBe 1 cpg.fieldAccess.fieldIdentifier.canonicalName("name").size shouldBe 2 cpg.fieldAccess.fieldIdentifier.canonicalName("age").size shouldBe 4 - cpg.identifier.size shouldBe 13 // 4 identifier node is for `puts = typeDef(__builtin.puts)` 1 node for class Person = typeDef + cpg.identifier.size shouldBe 16 // 4 identifier node is for `puts = typeDef(__builtin.puts)` 1 node for class Person = typeDef } "recognize all call nodes" in { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IdentifierTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IdentifierTests.scala index 7dd2f48fa3c0..351fc7065c67 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IdentifierTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/IdentifierTests.scala @@ -35,7 +35,8 @@ class IdentifierTests extends RubyCode2CpgFixture { cpg.identifier.name("num3").size shouldBe 1 cpg.identifier.name("sum").size shouldBe 2 cpg.identifier.name("ret").size shouldBe 2 - cpg.identifier.size shouldBe 16 // 2 identifier node is for methodRef's assigment + cpg.identifier.name("this").size shouldBe 2 + cpg.identifier.size shouldBe 18 // 2 identifier node is for methodRef's assigment } "identify a single call node" in { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala index 76997217302a..33409dd64a0c 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala @@ -32,7 +32,7 @@ class MiscTests extends RubyCode2CpgFixture { cpg.identifier.name("beginbool").size shouldBe 1 cpg.identifier.name("endbool").size shouldBe 1 cpg.call.name("puts").size shouldBe 1 - cpg.identifier.size shouldBe 7 // 1 identifier node is for `puts = typeDef(__builtin.puts)` + cpg.identifier.size shouldBe 8 // 1 identifier node is for `puts = typeDef(__builtin.puts)` } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala index d738b3a5bb01..0c9b70d55455 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala @@ -93,7 +93,7 @@ class SymbolTable[K <: SBKey](val keyFromNode: AstNode => Option[K]) { } def put(sbKey: K, typeFullNames: Set[String]): Set[String] = - if (typeFullNames.nonEmpty) { + if (!sbKey.identifier.isBlank && typeFullNames.nonEmpty) { val newEntry = coalesce(Set.empty, typeFullNames) table.put(sbKey, newEntry) newEntry @@ -122,6 +122,7 @@ class SymbolTable[K <: SBKey](val keyFromNode: AstNode => Option[K]) { def append(sbKey: K, typeFullNames: Set[String]): Set[String] = { table.get(sbKey) match { + case _ if sbKey.identifier.isBlank => Set.empty case Some(ts) if ts == typeFullNames => ts case Some(ts) if typeFullNames.nonEmpty => put(sbKey, coalesce(ts, typeFullNames)) case None if typeFullNames.nonEmpty => put(sbKey, coalesce(Set.empty, typeFullNames)) @@ -152,6 +153,18 @@ class SymbolTable[K <: SBKey](val keyFromNode: AstNode => Option[K]) { def view: MapView[K, Set[String]] = table.view + /** @return + * a deep copy of this symbol table. + */ + def copy(): SymbolTable[K] = { + val sb = SymbolTable[K](keyFromNode) + this.table.foreach { case (k, v) => sb.table.put(k, Set.from(v)) } + sb + } + def clear(): Unit = table.clear() + override def toString: String = + table.map { case (k, v) => s"$k -> [${v.mkString(",")}]" }.mkString("\n") + } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XImportResolverPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XImportResolverPass.scala index d2f78e7a32e3..a22230369f2b 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XImportResolverPass.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XImportResolverPass.scala @@ -49,11 +49,29 @@ object ImportsPass { def label: String def serialize: String + + def alias: String } - implicit class TagToResolvedImportExt(traversal: Iterator[Tag]) { + implicit class CallToResolvedImportExt(traversal: Iterator[Call]) { + def toResolvedImport: Iterator[ResolvedImport] = - traversal.flatMap(ResolvedImport.tagToResolvedImport) + traversal + .flatMap(c => + c.referencedImports.flatMap(i => + for { + alias <- i.importedAs + } yield c.tag.toResolvedImport(alias) + ) + ) + .flatten + .distinct + + } + + implicit class TagToResolvedImportExt(traversal: Iterator[Tag]) { + def toResolvedImport(alias: String): Iterator[ResolvedImport] = + traversal.flatMap(ResolvedImport.tagToResolvedImport(_, alias)) } object ResolvedImport { @@ -71,19 +89,19 @@ object ImportsPass { val OPT_BASE_PATH = "BASE_PATH" val OPT_NAME = "NAME" - def tagToResolvedImport(tag: Tag): Option[ResolvedImport] = Option(tag.name match { + def tagToResolvedImport(tag: Tag, alias: String): Option[ResolvedImport] = Option(tag.name match { case RESOLVED_METHOD => val opts = valueToOptions(tag.value) ResolvedMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) - case RESOLVED_TYPE_DECL => ResolvedTypeDecl(tag.value) + case RESOLVED_TYPE_DECL => ResolvedTypeDecl(tag.value, alias) case RESOLVED_MEMBER => val opts = valueToOptions(tag.value) - ResolvedMember(opts(OPT_BASE_PATH), opts(OPT_NAME)) + ResolvedMember(opts(OPT_BASE_PATH), opts(OPT_NAME), alias) case UNKNOWN_METHOD => val opts = valueToOptions(tag.value) UnknownMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) - case UNKNOWN_TYPE_DECL => UnknownTypeDecl(tag.value) - case UNKNOWN_IMPORT => UnknownImport(tag.value) + case UNKNOWN_TYPE_DECL => UnknownTypeDecl(tag.value, alias) + case UNKNOWN_IMPORT => UnknownImport(tag.value, alias) case _ => null }) @@ -113,13 +131,20 @@ object ImportsPass { .mkString(sep) } - case class ResolvedTypeDecl(fullName: String, override val label: String = RESOLVED_TYPE_DECL) - extends ResolvedImport { + case class ResolvedTypeDecl( + fullName: String, + override val alias: String, + override val label: String = RESOLVED_TYPE_DECL + ) extends ResolvedImport { override def serialize: String = fullName } - case class ResolvedMember(basePath: String, memberName: String, override val label: String = RESOLVED_MEMBER) - extends ResolvedImport { + case class ResolvedMember( + basePath: String, + memberName: String, + override val alias: String, + override val label: String = RESOLVED_MEMBER + ) extends ResolvedImport { override def serialize: String = Seq(OPT_BASE_PATH, basePath.encode, OPT_NAME, memberName.encode).mkString(sep) } @@ -136,11 +161,16 @@ object ImportsPass { .mkString(sep) } - case class UnknownTypeDecl(fullName: String, override val label: String = UNKNOWN_TYPE_DECL) extends ResolvedImport { + case class UnknownTypeDecl( + fullName: String, + override val alias: String, + override val label: String = UNKNOWN_TYPE_DECL + ) extends ResolvedImport { override def serialize: String = fullName } - case class UnknownImport(path: String, override val label: String = UNKNOWN_IMPORT) extends ResolvedImport { + case class UnknownImport(path: String, override val alias: String, override val label: String = UNKNOWN_IMPORT) + extends ResolvedImport { override def serialize: String = path } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeHintCallLinker.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeHintCallLinker.scala index 68b84ce88a52..b191c472055c 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeHintCallLinker.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeHintCallLinker.scala @@ -24,13 +24,15 @@ abstract class XTypeHintCallLinker(cpg: Cpg) extends CpgPass(cpg) { implicit protected val resolver: NoResolve.type = NoResolve private val fileNamePattern = Pattern.compile("^(.*(.py|.js|.rb)).*$") protected val pathSep: Char = '.' + private val methodReturnMap = + cpg.methodReturn.filterNot(_.typeFullName == "ANY").map { mr => mr.method.fullName -> mr.typeFullName }.toMap protected def calls: Iterator[Call] = cpg.call .nameNot(".*", ".*") .filter(c => calleeNames(c).nonEmpty && c.callee.isEmpty) protected def calleeNames(c: Call): Seq[String] = - c.dynamicTypeHintFullName.filterNot(_.equals("ANY")).distinct + (c.dynamicTypeHintFullName ++ c.possibleTypes).filterNot(_.equals("ANY")).distinct protected def callees(names: Seq[String]): List[Method] = cpg.method.fullNameExact(names: _*).toList @@ -88,15 +90,43 @@ abstract class XTypeHintCallLinker(cpg: Cpg) extends CpgPass(cpg) { } } - protected def setCallees(call: Call, methodNames: Seq[String], builder: DiffGraphBuilder): Unit = { - val nonDummyTypes = methodNames.filterNot(isDummyType) + protected def setCallees(call: Call, ms: Seq[String], builder: DiffGraphBuilder): Unit = { + + /** Filters in "most resolved" types, by looking at the common suffixes. + * + * @param ts + * the incoming types. + * @return + * the filtered set. + */ + def filterMostResolvedTypes(ts: Iterable[String]): Iterable[String] = { + ts.groupBy(_.split(pathSep).last).flatMap { case (_, xs) => + xs.sortBy(x => (XTypeRecovery.DummyTokens.count(x.contains), x.length)).headOption + } + } + + def resolveMethodRefs(m: String): String = + if (!m.endsWith(XTypeRecovery.DummyReturnType)) { + m.split(pathSep).toList match + case head :: next => + methodReturnMap.get((head +: next.take(next.length - 1)).mkString(pathSep.toString)) match + case Some(path) if next.nonEmpty => s"$path$pathSep${next.last}" + case Some(path) => path + case _ => m + case Nil => m + } else { + m + } + + val methodNames = ms.map(resolveMethodRefs).distinct + lazy val nonDummyTypes = methodNames.filterNot(isDummyType) + lazy val mostResolvedTypes = filterMostResolvedTypes(methodNames).toSeq if (methodNames.sizeIs == 1) { builder.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, methodNames.head) - builder.setNodeProperty( - call, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - call.dynamicTypeHintFullName.diff(methodNames) - ) + builder.setNodeProperty(call, PropertyNames.POSSIBLE_TYPES, call.possibleTypes.diff(methodNames)) + } else if (mostResolvedTypes.sizeIs == 1 && nonDummyTypes.isEmpty) { + builder.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, mostResolvedTypes.head) + builder.setNodeProperty(call, PropertyNames.POSSIBLE_TYPES, call.possibleTypes.diff(mostResolvedTypes)) } else if (methodNames.sizeIs > 1 && methodNames != nonDummyTypes) { setCallees(call, nonDummyTypes, builder) } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala index 96f1fdc6eaff..ad4461b1a6e4 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala @@ -1,5 +1,6 @@ package io.joern.x2cpg.passes.frontend +import io.joern.x2cpg.passes.frontend.ImportsPass.* import io.joern.x2cpg.{Defines, X2CpgConfig} import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes.* @@ -13,18 +14,21 @@ import overflowdb.BatchedUpdate import overflowdb.BatchedUpdate.DiffGraphBuilder import scopt.OParser -import java.util.concurrent.RecursiveTask import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{Callable, ExecutorService, Executors} import scala.annotation.tailrec import scala.collection.concurrent.TrieMap -import scala.collection.mutable +import scala.collection.immutable.Set +import scala.collection.{Iterator, mutable} +import scala.util.matching.Regex +import scala.util.{Failure, Success, Try} /** @param iterations * the number of iterations to run. * @param enabledDummyTypes * whether to enable placeholder dummy values for partially resolved types. */ -case class XTypeRecoveryConfig(iterations: Int = 2, enabledDummyTypes: Boolean = true) +case class TypeRecoveryConfig(iterations: Int = 2, enabledDummyTypes: Boolean = true) /** @param config * the user defined config. @@ -37,12 +41,13 @@ case class XTypeRecoveryConfig(iterations: Int = 2, enabledDummyTypes: Boolean = * @param stopEarly * indicates that we may stop type propagation earlier than the specified number of iterations. */ -case class XTypeRecoveryState( - config: XTypeRecoveryConfig = XTypeRecoveryConfig(), +case class TypeRecoveryState( + config: TypeRecoveryConfig = TypeRecoveryConfig(), currentIteration: Int = 0, isFieldCache: TrieMap[Long, Boolean] = TrieMap.empty[Long, Boolean], changesWereMade: AtomicBoolean = new AtomicBoolean(false), - stopEarly: AtomicBoolean + stopEarly: AtomicBoolean, + graphCache: GraphCache ) { lazy val isFinalIteration: Boolean = currentIteration == config.iterations - 1 @@ -52,38 +57,86 @@ case class XTypeRecoveryState( } +case class GraphCache( + methodReturnTypes: Map[String, Set[String]], + memberTypes: Map[(String, String), Set[String]], + typeDecls: Set[String] +) + /** In order to propagate types across compilation units, but avoid the poor scalability of a fixed-point algorithm, the * number of iterations can be configured using the iterations parameter. Note that iterations < 2 will not provide any * interprocedural type recovery capabilities. * @param cpg * the CPG to recovery types for. - * - * @tparam CompilationUnitType - * the AstNode type used to represent a compilation unit of the language. */ -abstract class XTypeRecoveryPass[CompilationUnitType <: AstNode]( - cpg: Cpg, - config: XTypeRecoveryConfig = XTypeRecoveryConfig() -) extends CpgPass(cpg) { +abstract class XTypeRecoveryPass(cpg: Cpg, config: TypeRecoveryConfig = TypeRecoveryConfig()) extends CpgPass(cpg) { + + import io.joern.x2cpg.passes.frontend.XTypeRecovery.AllNodeTypesFromNodeExt + + private def initGraphCache: GraphCache = + GraphCache( + methodReturnTypes = cpg.methodReturn + .map(mr => mr.method.fullName -> mr.allTypes.filterNot(_ == "ANY").toSet) + .filterNot(_._2.isEmpty) + .toMap, + memberTypes = cpg.member + .map(m => (m.typeDecl.fullName, m.name) -> m.allTypes.filterNot(_ == "ANY").toSet) + .filterNot(_._2.isEmpty) + .toMap, + typeDecls = cpg.typeDecl.fullName.toSet + ) override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = if (config.iterations > 0) { val stopEarly = new AtomicBoolean(false) - val state = XTypeRecoveryState(config, stopEarly = stopEarly) + val state = TypeRecoveryState(config, stopEarly = stopEarly, graphCache = initGraphCache) + val executor = Executors.newWorkStealingPool() try { Iterator.from(0).takeWhile(_ < config.iterations).foreach { i => - val newState = state.copy(currentIteration = i) - generateRecoveryPass(newState).createAndApply() + val newState = state.copy(currentIteration = i, graphCache = initGraphCache) + generateRecoveryPass(newState, executor).createAndApply() } // If dummy values are enabled and we are stopping early, we need one more round to propagate these dummy values - if (stopEarly.get() && config.enabledDummyTypes) - generateRecoveryPass(state.copy(currentIteration = config.iterations - 1)).createAndApply() + if (stopEarly.get() && config.enabledDummyTypes) { + generateRecoveryPass( + state.copy(currentIteration = config.iterations - 1, graphCache = initGraphCache), + executor + ) + .createAndApply() + } + postTypeRecoveryAndPropagation(builder) } finally { state.clear() + executor.shutdown() } } - protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[CompilationUnitType] + protected def generateRecoveryPass(state: TypeRecoveryState, executor: ExecutorService): XTypeRecovery + + /** A hook for the end of the type recovery and propagation. + */ + protected def postTypeRecoveryAndPropagation(builder: DiffGraphBuilder): Unit = { + linkMembersToTheirRefs(builder) + } + + private def linkMembersToTheirRefs(builder: DiffGraphBuilder): Unit = { + import XTypeRecovery.unknownTypePattern + // Set all now-typed fieldAccess calls to their referencing members (if they exist) + cpg.fieldAccess + .where( + _.and( + _.not(_.referencedMember), + _.argument(1).isIdentifier.typeFullNameNot(unknownTypePattern.pattern.pattern()) + ) + ) + .foreach { fieldAccess => + cpg.typeDecl + .fullNameExact(fieldAccess.argument(1).getKnownTypes.toSeq: _*) + .member + .nameExact(fieldAccess.fieldIdentifier.canonicalName.toSeq: _*) + .foreach(builder.addEdge(fieldAccess, _, EdgeTypes.REF)) + } + } } @@ -98,7 +151,7 @@ trait TypeRecoveryParserConfig[R <: X2CpgConfig[R]] { this: R => } def withTypePropagationIterations(value: Int): R = { - typePropagationIterations = value + this.typePropagationIterations = value this } @@ -123,38 +176,105 @@ trait TypeRecoveryParserConfig[R <: X2CpgConfig[R]] { this: R => * * @param cpg * the CPG to recovery types for. - * @tparam CompilationUnitType - * the AstNode type used to represent a compilation unit of the language. */ -abstract class XTypeRecovery[CompilationUnitType <: AstNode](cpg: Cpg, state: XTypeRecoveryState) extends CpgPass(cpg) { - - override def run(builder: DiffGraphBuilder): Unit = { - val changesWereMade = compilationUnit - .map(unit => generateRecoveryForCompilationUnitTask(unit, builder).fork()) - .map(_.get) - .reduceOption((a, b) => a || b) - .getOrElse(false) - if (!changesWereMade) state.stopEarly.set(true) - } +abstract class XTypeRecovery(cpg: Cpg, state: TypeRecoveryState, executor: ExecutorService) extends CpgPass(cpg) { - /** @return - * the compilation units as per how the language is compiled. e.g. file. - */ - def compilationUnit: Iterator[CompilationUnitType] + import io.joern.x2cpg.passes.frontend.XTypeRecovery.{AllNodeTypesFromIteratorExt, AllNodeTypesFromNodeExt} + + protected val logger: Logger = LoggerFactory.getLogger(getClass) + protected val initialSymbolTable: SymbolTable[LocalKey] = SymbolTable[LocalKey](SBKey.fromNodeToLocalKey) + + override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = + cpg.file.toArray.foreach(runOnCompilationUnit(builder, _)) - /** A factory method to generate a [[RecoverForXCompilationUnit]] task with the given parameters. + private def runOnCompilationUnit(builder: DiffGraphBuilder, part: File): Unit = { + importNodes(part).foreach(loadImports(_, initialSymbolTable)) + // Prune import names if the methods exist in the CPG + postVisitImports() + // Make a new task and absorb the resulting builder from each task + part.method.toArray + .map(methodToTask) + .map(executor.submit) + .map(task => Try(task.get())) + .foreach { + case Failure(exception) => + logger.error(s"Type recovery & propagation task failed for file '${part.name}'", exception) + case Success(diffGraph) => + builder.absorb(diffGraph) + } + if (!state.changesWereMade.get()) state.stopEarly.set(true) + } + + /** A factory method to generate a [[RecoverTypesForProcedure]] task with the given parameters. * * @param unit - * the compilation unit. + * the procedure. + * @return + * a forkable [[RecoverTypesForProcedure]] task. + */ + private def methodToTask(unit: Method): RecoverTypesForProcedure = + recoverTypesForProcedure(cpg, unit, initialSymbolTable.copy(), new DiffGraphBuilder, state) + + /** The entrypoint for a type recovery and propagation task. + * @param cpg + * the graph. + * @param procedure + * the target method. + * @param initialSymbolTable + * the initial symbol table containing imported symbols. * @param builder - * the graph builder. + * the builder. + * @param state + * state information for this iteration of the type recovery and propagation algorithm. * @return - * a forkable [[RecoverForXCompilationUnit]] task. + * a callable recovery task. + */ + protected def recoverTypesForProcedure( + cpg: Cpg, + procedure: Method, + initialSymbolTable: SymbolTable[LocalKey], + builder: DiffGraphBuilder, + state: TypeRecoveryState + ): RecoverTypesForProcedure + + protected def importNodes(cu: File): List[ResolvedImport] = cu.ast.isCall.referencedImports.flatMap(visitImport).l + + /** Visits an import and stores references in the symbol table as both an identifier and call. + */ + protected def visitImport(i: Import): Iterator[ResolvedImport] = { + import io.joern.x2cpg.passes.frontend.ImportsPass.* + + i.call.tag.flatMap(tag => i.importedAs.flatMap(ResolvedImport.tagToResolvedImport(tag, _))) + } + + /** Visits an import and stores references in the symbol table as both an identifier and call. + */ + protected def loadImports(i: ResolvedImport, symbolTable: SymbolTable[LocalKey]): Unit = i match { + case ResolvedMethod(fullName, alias, receiver, _) => + symbolTable.append(CallAlias(alias, receiver), fullName) + case ResolvedTypeDecl(fullName, alias, _) => + symbolTable.append(LocalVar(alias), fullName) + case ResolvedMember(basePath, memberName, alias, _) => + val matchingIdentifiers = cpg.method.fullNameExact(basePath).local + val matchingMembers = cpg.typeDecl.fullNameExact(basePath).member + val memberTypes = (matchingMembers ++ matchingIdentifiers) + .nameExact(memberName) + .getKnownTypes + symbolTable.append(LocalVar(alias), memberTypes) + case UnknownMethod(fullName, alias, receiver, _) => + symbolTable.append(CallAlias(alias, receiver), fullName) + case UnknownTypeDecl(fullName, alias, _) => + symbolTable.append(LocalVar(alias), fullName) + case UnknownImport(path, alias, _) => + symbolTable.append(CallAlias(alias), path) + symbolTable.append(LocalVar(alias), path) + } + + /** The initial import setting is over-approximated, so this step checks the CPG for any matches and prunes against + * these findings. If there are no findings, it will leave the table as is. The latter is significant for external + * types or methods. */ - def generateRecoveryForCompilationUnitTask( - unit: CompilationUnitType, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[CompilationUnitType] + protected def postVisitImports(): Unit = {} } @@ -162,10 +282,12 @@ object XTypeRecovery { private val logger = LoggerFactory.getLogger(getClass) - val DummyReturnType = "" - val DummyMemberLoad = "" - val DummyIndexAccess = "" - private lazy val DummyTokens: Set[String] = Set(DummyReturnType, DummyMemberLoad, DummyIndexAccess) + val DummyReturnType = "" + val DummyMemberLoad = "" + val DummyIndexAccess = "" + lazy val DummyTokens: Set[String] = Set(DummyReturnType, DummyMemberLoad, DummyIndexAccess) + + val unknownTypePattern: Regex = s"(i?)(UNKNOWN|ANY|${Defines.UnresolvedNamespace}).*".r def dummyMemberType(prefix: String, memberName: String, sep: Char = '.'): String = s"$prefix$sep$DummyMemberLoad($memberName)" @@ -201,31 +323,55 @@ object XTypeRecovery { ) } + // The below are convenience calls for accessing type properties, one day when this pass uses `Tag` nodes instead of + // the symbol table then perhaps this would work out better + implicit class AllNodeTypesFromNodeExt(x: StoredNode) { + + def allTypes: Iterator[String] = (Seq(x.property(PropertyNames.TYPE_FULL_NAME, "ANY")) ++ + x.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) ++ x.property( + PropertyNames.POSSIBLE_TYPES, + Seq.empty + )).iterator + + def getKnownTypes: Set[String] = { + x.allTypes.toSet.filterNot(XTypeRecovery.unknownTypePattern.matches) + } + } + + implicit class AllNodeTypesFromIteratorExt(x: Iterator[StoredNode]) { + def allTypes: Iterator[String] = x.flatMap(_.allTypes) + + def getKnownTypes: Set[String] = + x.allTypes.toSet.filterNot(XTypeRecovery.unknownTypePattern.matches) + } + } /** Performs type recovery from the root of a compilation unit level * * @param cpg * the graph. - * @param cu - * a compilation unit, e.g. file, procedure, type, etc. + * @param procedure + * some executable method or function. + * @param symbolTable + * stores type information for local structures that live within this compilation unit, e.g. local variables. and is + * pre-loaded with imported symbols. * @param builder - * the graph builder - * @tparam CompilationUnitType - * the AstNode type used to represent a compilation unit of the language. + * the graph builder, should be returned in the `call` function. + * @param state + * the state of the type recovery. */ -abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( +abstract class RecoverTypesForProcedure( cpg: Cpg, - cu: CompilationUnitType, + procedure: Method, + symbolTable: SymbolTable[LocalKey], builder: DiffGraphBuilder, - state: XTypeRecoveryState -) extends RecursiveTask[Boolean] { + state: TypeRecoveryState +) extends Callable[DiffGraphBuilder] { - protected val logger: Logger = LoggerFactory.getLogger(getClass) + import io.joern.x2cpg.passes.frontend.XTypeRecovery.{AllNodeTypesFromIteratorExt, AllNodeTypesFromNodeExt} - /** Stores type information for local structures that live within this compilation unit, e.g. local variables. - */ - protected val symbolTable = new SymbolTable[LocalKey](SBKey.fromNodeToLocalKey) + protected val logger: Logger = LoggerFactory.getLogger(getClass) /** The root of the target codebase. */ @@ -237,26 +383,35 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( /** New node tracking set. */ - protected val addedNodes = mutable.HashSet.empty[String] + private val addedNodes = mutable.HashSet.empty[String] - /** For tracking members and the type operations that need to be performed. Since these are mostly out of scope - * locally it helps to track these separately. + /** For tracking interprocedural nodes and the type operations that need to be performed. Since these are mostly out + * of scope locally it helps to track these separately. * * // TODO: Potentially a new use for a global table or modification to the symbol table? */ - protected val newTypesForMembers = mutable.HashMap.empty[Member, Set[String]] + private val newTypesForInterprocNodes = mutable.HashMap.empty[AstNode, Set[String]] /** Provides an entrypoint to add known symbols and their possible types. */ protected def prepopulateSymbolTable(): Unit = { - (cu.ast.isIdentifier ++ cu.ast.isCall ++ cu.ast.isLocal ++ cu.ast.isParameter) + val procAndParentProc = procedure.inAst.isMethod.l + val declarations = (procAndParentProc.local ++ procAndParentProc.parameter.filter(p => + if (p.index == 0) p.method == procedure else true + )).l + val identifiers = procAndParentProc.flatMap(_._identifierViaContainsOut).l + val calls = procedure.call + .nameNot(" c.name -> c.argument.headOption) + + (declarations ++ identifiers ++ calls) .filter(hasTypes) .foreach(prepopulateSymbolTableEntry) } protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match { case x @ (_: Identifier | _: Local | _: MethodParameterIn) => symbolTable.append(x, x.getKnownTypes) - case x: Call => symbolTable.append(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) + case x: Call => symbolTable.append(x, (x.methodFullName +: x.possibleTypes).filterNot(_.startsWith(" } @@ -266,22 +421,13 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( case x => x.getKnownTypes.nonEmpty } - protected def assignments: Iterator[Assignment] = - cu.ast.isCall.nameExact(Operators.assignment).map(new OpNodes.Assignment(_)) - - protected def members: Iterator[Member] = cu.ast.isMember - - protected def returns: Iterator[Return] = cu.ast.isReturn + protected def assignments: Iterator[Assignment] = procedure.assignment - protected def importNodes: Iterator[Import] = cu.ast.isCall.referencedImports + protected def returns: Iterator[Return] = procedure.methodReturn.toReturn - override def compute(): Boolean = try { - // Set known aliases that point to imports for local and external methods/modules - importNodes.foreach(visitImport) + override def call(): DiffGraphBuilder = try { // Look at symbols with existing type info prepopulateSymbolTable() - // Prune import names if the methods exist in the CPG - postVisitImports() // Populate local symbol table with assignments assignments.foreach(visitAssignments) // See if any new information are in the parameters of methods @@ -290,8 +436,10 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( setTypeInformation() // Entrypoint for any final changes postSetTypeInformation() - // Return number of changes - state.changesWereMade.get() + // Clear symbol table now in case the return value is blocked by the task queue + symbolTable.clear() + // Return diff graph + builder } finally { symbolTable.clear() } @@ -302,63 +450,27 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( s"$fileName#L$lineNo" } - /** Visits an import and stores references in the symbol table as both an identifier and call. - */ - protected def visitImport(i: Import): Unit = for { - resolvedImport <- i.call.tag - alias <- i.importedAs - } { - import io.joern.x2cpg.passes.frontend.ImportsPass.* - - ResolvedImport.tagToResolvedImport(resolvedImport).foreach { - case ResolvedMethod(fullName, alias, receiver, _) => - symbolTable.append(CallAlias(alias, receiver), fullName) - case ResolvedTypeDecl(fullName, _) => - symbolTable.append(LocalVar(alias), fullName) - case ResolvedMember(basePath, memberName, _) => - val matchingIdentifiers = cpg.method.fullNameExact(basePath).local - val matchingMembers = cpg.typeDecl.fullNameExact(basePath).member - val memberTypes = (matchingMembers ++ matchingIdentifiers) - .nameExact(memberName) - .getKnownTypes - symbolTable.append(LocalVar(alias), memberTypes) - case UnknownMethod(fullName, alias, receiver, _) => - symbolTable.append(CallAlias(alias, receiver), fullName) - case UnknownTypeDecl(fullName, _) => - symbolTable.append(LocalVar(alias), fullName) - case UnknownImport(path, _) => - symbolTable.append(CallAlias(alias), path) - symbolTable.append(LocalVar(alias), path) - } - } - - /** The initial import setting is over-approximated, so this step checks the CPG for any matches and prunes against - * these findings. If there are no findings, it will leave the table as is. The latter is significant for external - * types or methods. - */ - protected def postVisitImports(): Unit = {} - /** Using assignment and import information (in the global symbol table), will propagate these types in the symbol * table. * * @param a * assignment call pointer. */ - protected def visitAssignments(a: Assignment): Set[String] = { - a.argumentOut.l match { - case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) - case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c) - case List(x: Identifier, y: Identifier) => visitIdentifierAssignedToIdentifier(x, y) - case List(i: Identifier, l: Literal) if state.isFirstIteration => visitIdentifierAssignedToLiteral(i, l) - case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m) - case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t) - case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i) - case List(x: Call, y: Call) => visitCallAssignedToCall(x, y) - case List(c: Call, l: Literal) if state.isFirstIteration => visitCallAssignedToLiteral(c, l) - case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m) - case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b) - case _ => Set.empty - } + protected def visitAssignments(a: Assignment): Set[String] = visitAssignmentArguments(a.argumentOut.l) + + protected def visitAssignmentArguments(args: List[AstNode]): Set[String] = args match { + case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) + case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c) + case List(x: Identifier, y: Identifier) => visitIdentifierAssignedToIdentifier(x, y) + case List(i: Identifier, l: Literal) if state.isFirstIteration => visitIdentifierAssignedToLiteral(i, l) + case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m) + case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t) + case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i) + case List(x: Call, y: Call) => visitCallAssignedToCall(x, y) + case List(c: Call, l: Literal) if state.isFirstIteration => visitCallAssignedToLiteral(c, l) + case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m) + case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b) + case _ => Set.empty } /** Visits an identifier being assigned to the result of some operation. @@ -505,11 +617,14 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( // We have been able to resolve the type inter-procedurally associateTypes(i, globalTypes) } else if (baseTypes.nonEmpty) { + lazy val existingMembers = cpg.typeDecl.fullNameExact(baseTypes.toSeq: _*).member.nameExact(fieldName) if (baseTypes.equals(symbolTable.get(LocalVar(fieldFullName)))) { associateTypes(i, baseTypes) - } else { + } else if (existingMembers.isEmpty) { // If not available, use a dummy variable that can be useful for call matching associateTypes(i, baseTypes.map(t => XTypeRecovery.dummyMemberType(t, fieldName, pathSep))) + } else { + Set.empty } } else { // Assign dummy @@ -544,7 +659,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( protected def visitIdentifierAssignedToCallRetVal(i: Identifier, c: Call): Set[String] = { if (symbolTable.contains(c)) { val callReturns = methodReturnValues(symbolTable.get(c).toSeq) - associateTypes(i, callReturns) + associateTypes(i, callReturns ++ i.getKnownTypes) } else if (c.argument.exists(_.argumentIndex == 0)) { val callFullNames = (c.argument(0) match { case i: Identifier if symbolTable.contains(LocalVar(i.name)) => symbolTable.get(LocalVar(i.name)) @@ -562,12 +677,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( /** Will attempt to find the return values of a method if in the CPG, otherwise will give a dummy value. */ protected def methodReturnValues(methodFullNames: Seq[String]): Set[String] = { - val rs = cpg.method - .fullNameExact(methodFullNames: _*) - .methodReturn - .flatMap(mr => mr.typeFullName +: mr.dynamicTypeHintFullName) - .filterNot(_.equals("ANY")) - .toSet + val rs = methodFullNames.flatMap(state.graphCache.methodReturnTypes.getOrElse(_, Set.empty)).toSet if (rs.isEmpty) methodFullNames.map(_.concat(s"$pathSep${XTypeRecovery.DummyReturnType}")).toSet else rs } @@ -611,7 +721,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( */ protected def getTypesFromCall(c: Call): Set[String] = c.name match { case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(c)))) - case _ if symbolTable.contains(c) => symbolTable.get(c) + case _ if symbolTable.contains(c) => methodReturnValues(symbolTable.get(c).toSeq) case Operators.indexAccess => getIndexAccessTypes(c) case n => logger.debug(s"Unknown RHS call type '$n' @ ${debugLocation(c)}") @@ -694,6 +804,8 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( // TODO: Handle this case better val callCode = if (c.code.contains("(")) c.code.substring(c.code.indexOf("(")) else c.code XTypeRecovery.dummyMemberType(callCode, f.canonicalName, pathSep) + case ::(_: TypeRef, ::(f: FieldIdentifier, _)) => + f.canonicalName case xs => logger.warn(s"Unhandled field structure ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(fa)}") wrapName("") @@ -815,7 +927,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( /** Visits an identifier that is the target of a cast operation. */ protected def visitIdentifierAssignedToCast(i: Identifier, c: Call): Set[String] = - associateTypes(i, (c.typeFullName +: c.dynamicTypeHintFullName).filterNot(_ == "ANY").toSet) + associateTypes(i, c.getKnownTypes) protected def getFieldBaseType(base: Identifier, fi: FieldIdentifier): Set[String] = getFieldBaseType(base.name, fi.canonicalName) @@ -825,13 +937,13 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( .get(LocalVar(baseName)) .flatMap(t => typeDeclIterator(t).member.nameExact(fieldName)) .typeFullNameNot("ANY") - .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) + .flatMap(_.getKnownTypes) .toSet protected def visitReturns(ret: Return): Unit = { val m = ret.method val existingTypes = mutable.HashSet.from( - (m.methodReturn.typeFullName +: m.methodReturn.dynamicTypeHintFullName) + (m.methodReturn.typeFullName +: (m.methodReturn.dynamicTypeHintFullName ++ m.methodReturn.possibleTypes)) .filterNot(_ == "ANY") ) @tailrec @@ -845,7 +957,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( .fullNameExact(ts.map(_.compUnitFullName).toSeq: _*) .member .nameExact(sym.identifier) - .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) + .flatMap(m => m.typeFullName +: (m.dynamicTypeHintFullName ++ m.possibleTypes)) .filterNot { x => x == "ANY" || x == "this" } .toSet if (cpgTypes.nonEmpty) cpgTypes @@ -869,19 +981,22 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( } val returnTypes = extractTypes(ret.argumentOut.l) existingTypes.addAll(returnTypes) - builder.setNodeProperty(ret.method.methodReturn, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, existingTypes) + newTypesForInterprocNodes.updateWith(m.methodReturn) { + case Some(xs) if existingTypes.nonEmpty => Some(xs ++ existingTypes) + case None if existingTypes.nonEmpty => Some(existingTypes.toSet) + case _ => None + } } /** Using an entry from the symbol table, will queue the CPG modification to persist the recovered type information. */ protected def setTypeInformation(): Unit = { - cu.ast + (procedure.capturedByMethodRef.referencedMethod ++ procedure).ast .collect { - case n: Local => n - case n: Call => n - case n: Expression => n - case n: MethodParameterIn if state.isFinalIteration => n - case n: MethodReturn if state.isFinalIteration => n + case n: Local => n + case n: Expression => n + case n: MethodParameterIn => n + case n: MethodReturn => n } .foreach { case x: Local if symbolTable.contains(x) => storeNodeTypeInfo(x, symbolTable.get(x).toSeq) @@ -900,8 +1015,10 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( setTypeInformationForRecCall(x, Option(x), x.argument.l) case _ => } - // Set types in an atomic way - newTypesForMembers.foreach { case (m, ts) => storeDefaultTypeInfo(m, ts.toSeq) } + // Set types in an "atomic" way + newTypesForInterprocNodes.foreach { case (m, ts) => + storeDefaultTypeInfo(m, ts.toSeq) + } } protected def createCallFromIdentifierTypeFullName(typeFullName: String, callName: String): String = @@ -948,8 +1065,13 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( } protected def setTypeForDynamicDispatchCall(call: Call, i: Identifier): Unit = { - val idHints = symbolTable.get(i) + val idHints = if (i.name == "this" || i.name == "self") { + i.inAst.collectFirst { case x: Method => x.typeDecl.fullName.toSet }.getOrElse(symbolTable.get(i)) + } else { + symbolTable.get(i) + } val callTypes = symbolTable.get(call) + persistType(i, idHints) if (callTypes.isEmpty && !call.name.startsWith("")) // For now, calls are treated as function pointers and thus the type should point to the method @@ -986,7 +1108,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( /** In the case this field access is a function pointer, we would want to make sure this has a method ref. */ - private def handlePotentialFunctionPointer( + protected def handlePotentialFunctionPointer( funcPtr: Expression, baseTypes: Set[String], funcName: String, @@ -1103,7 +1225,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( storedNode match { case m: Member => // To avoid overwriting member updates, we store them elsewhere until the end - newTypesForMembers.updateWith(m) { + newTypesForInterprocNodes.updateWith(m) { case Some(ts) => Some(ts ++ types) case None => Some(types.toSet) } @@ -1121,11 +1243,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = if (types.nonEmpty) { state.changesWereMade.compareAndSet(false, true) - builder.setNodeProperty( - c, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - (c.dynamicTypeHintFullName ++ types).distinct - ) + builder.setNodeProperty(c, PropertyNames.POSSIBLE_TYPES, (c.getKnownTypes ++ types).toSeq) } /** Allows one to modify the types assigned to identifiers. @@ -1135,18 +1253,69 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( /** Allows one to modify the types assigned to nodes otherwise. */ - protected def storeDefaultTypeInfo(n: StoredNode, types: Seq[String]): Unit = + protected def storeDefaultTypeInfo(n: StoredNode, types: Seq[String]): Unit = { if (types.toSet != n.getKnownTypes) { state.changesWereMade.compareAndSet(false, true) - setTypes(n, (n.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) ++ types).distinct) + setTypes(n, (n.getKnownTypes ++ types).toSeq) } + } /** If there is only 1 type hint then this is set to the `typeFullName` property and `dynamicTypeHintFullName` is * cleared. If not then `dynamicTypeHintFullName` is set to the types. */ - protected def setTypes(n: StoredNode, types: Seq[String]): Unit = - if (types.size == 1) builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head) - else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types) + protected def setTypes(n: StoredNode, types: Seq[String]): Unit = { + val pattern = s"""^\\(([\\w_]+)\\)(.*)""".r + + def resolveDummyMethodReturn(t: String): Set[String] = { + val split = t.split(s"$pathSep${XTypeRecovery.DummyReturnType}", -1) + if (split.length - 1 > 1) Set(t) + else + split.toList match + case methodFullName :: suffix => + state.graphCache.methodReturnTypes + .get(methodFullName) + .map(xs => xs.map(x => s"$x${suffix.mkString}")) + .getOrElse(Set(t)) + case Nil => Set(t) + } + + def resolveMemberAccess(t: String): Set[String] = if (t.contains(XTypeRecovery.DummyMemberLoad)) { + val split = t.split(s"$pathSep${XTypeRecovery.DummyMemberLoad}", -1) + if (split.length - 1 > 1) { + Set(t) + } else { + split.toList match + case baseType :: next => + next.mkString match + case pattern(key, suffix) => + state.graphCache.memberTypes + .get((baseType, key)) + .map(xs => xs.map(x => s"$x${suffix.mkString}")) + .getOrElse(Set(t)) + case _ => Set(t) + case Nil => + Set(t) + } + } else { + Set(t) + } + + def filterMostResolvedTypes(ts: Iterable[String]): Iterable[String] = { + ts.groupBy(_.split(pathSep).last).flatMap { case (_, xs) => + xs.sortBy(x => (XTypeRecovery.DummyTokens.count(x.contains), x.length)).headOption + } + } + + val resolvedTypes = filterMostResolvedTypes( + types + .flatMap(resolveDummyMethodReturn) + .flatMap(resolveMemberAccess) + ).toSeq.distinct + val realTypes = resolvedTypes.filter(state.graphCache.typeDecls.contains) + if (realTypes.size == 1) builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, realTypes.head) + else if (resolvedTypes.size == 1) builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, resolvedTypes.head) + else builder.setNodeProperty(n, PropertyNames.POSSIBLE_TYPES, resolvedTypes) + } /** Allows one to modify the types assigned to locals. */ @@ -1158,26 +1327,4 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( */ protected def postSetTypeInformation(): Unit = {} - private val unknownTypePattern = s"(i?)(UNKNOWN|ANY|${Defines.UnresolvedNamespace}).*".r - - // The below are convenience calls for accessing type properties, one day when this pass uses `Tag` nodes instead of - // the symbol table then perhaps this would work out better - implicit class AllNodeTypesFromNodeExt(x: StoredNode) { - def allTypes: Iterator[String] = (x.property(PropertyNames.TYPE_FULL_NAME, "ANY") +: x.property( - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - Seq.empty - )).iterator - - def getKnownTypes: Set[String] = { - x.allTypes.toSet.filterNot(unknownTypePattern.matches) - } - } - - implicit class AllNodeTypesFromIteratorExt(x: Iterator[StoredNode]) { - def allTypes: Iterator[String] = x.flatMap(_.allTypes) - - def getKnownTypes: Set[String] = - x.allTypes.toSet.filterNot(unknownTypePattern.matches) - } - }