fork of https://github.com/sourcegraph/zoekt
1package scala.meta.internal.mtags
2
3import java.io.UncheckedIOException
4import java.nio.CharBuffer
5import java.nio.charset.StandardCharsets
6
7import scala.collection.concurrent.TrieMap
8import scala.util.Properties
9
10import scala.meta.Dialect
11import scala.meta.inputs.Input
12import scala.meta.internal.io.FileIO
13import scala.meta.internal.io.PathIO
14import scala.meta.internal.mtags.MtagsEnrichments._
15import scala.meta.internal.semanticdb.Scala._
16import scala.meta.internal.{semanticdb => s}
17import scala.meta.io.AbsolutePath
18
/**
 * A place where a symbol is defined.
 *
 * @param path the file that contains the definition.
 * @param range the range of the definition inside `path`, when one is known.
 */
final case class SymbolLocation(path: AbsolutePath, range: Option[s.Range])
23
/**
 * Index split on buckets per dialect in order to have a constant time
 * and low memory footprint to infer dialect for SymbolDefinition because
 * it's used in WorkspaceSymbolProvider
 *
 * @param toplevels keys are non-trivial toplevel symbols and values are the files
 *                  the symbols are defined in.
 * @param definitions keys are global symbols and the values are the files the symbols
 *                    are defined in. Difference between toplevels and definitions
 *                    is that toplevels contains only symbols generated by ScalaToplevelMtags
 *                    while definitions contains only symbols generated by ScalaMtags.
 * @param sourceJars loader used to register source directories/jars (`addEntry`),
 *                   to resolve relative source paths (`loadAll`) and to load
 *                   classes (`loadClass`); closed by [[close]].
 * @param toIndexSource mapping applied to a file right before its content is read
 *                      for full indexing, see `addMtagsSourceFile`.
 * @param mtags indexer used to compute toplevel symbols and full documents.
 * @param dialect the dialect every source of this bucket is parsed with.
 */
class SymbolIndexBucket(
    toplevels: TrieMap[String, Set[AbsolutePath]],
    definitions: TrieMap[String, Set[SymbolLocation]],
    sourceJars: ClasspathLoader,
    toIndexSource: AbsolutePath => AbsolutePath = identity,
    mtags: Mtags,
    dialect: Dialect
) {

  import SymbolIndexBucket.stdLibPatches

  /** Releases the resources held by the underlying `sourceJars` loader. */
  def close(): Unit = sourceJars.close()

  /**
   * Indexes the toplevel symbols of every Scala source below `dir`.
   *
   * Nothing is indexed when `sourceJars.addEntry(dir)` returns false
   * (presumably because the directory was already registered -- TODO confirm).
   *
   * @return the (symbol, source file) pairs produced by [[addSourceFile]].
   */
  def addSourceDirectory(dir: AbsolutePath): List[(String, AbsolutePath)] = {
    if (sourceJars.addEntry(dir)) {
      dir.listRecursive.toList.flatMap {
        case source if source.isScala =>
          addSourceFile(source, Some(dir)).map(sym => (sym, source))
        case _ =>
          List.empty
      }
    } else
      List.empty
  }

  /**
   * Indexes the toplevel symbols of every Scala source inside the source jar `jar`.
   *
   * Nothing is indexed when `sourceJars.addEntry(jar)` returns false. A jar whose
   * traversal fails with an `UncheckedIOException` yields an empty result.
   *
   * @return the (symbol, source file) pairs produced by [[addSourceFile]].
   */
  def addSourceJar(jar: AbsolutePath): List[(String, AbsolutePath)] = {
    if (sourceJars.addEntry(jar)) {
      FileIO.withJarFileSystem(jar, create = false) { root =>
        try {
          root.listRecursive.toList.flatMap {
            case source if source.isScala =>
              addSourceFile(source, None, Some(jar)).map(sym => (sym, source))
            case _ =>
              List.empty
          }
        } catch {
          // this happens in broken jars since a file from FileWalker should exist
          case _: UncheckedIOException => Nil
        }
      }
    } else
      List.empty
  }

  /**
   * Registers `jar` together with toplevel symbols that were computed earlier,
   * so the sources do not have to be parsed again.
   *
   * Symbols coming from the Scala 3 library are patched with
   * `stdLibPatches.patchSymbol` before they are recorded.
   *
   * @param jar the source jar the symbols belong to.
   * @param symbols (toplevel symbol, defining file) pairs.
   */
  def addIndexedSourceJar(
      jar: AbsolutePath,
      symbols: List[(String, AbsolutePath)]
  ): Unit = {
    if (sourceJars.addEntry(jar)) {
      val patched =
        if (stdLibPatches.isScala3Library(jar))
          symbols.map { case (sym, path) =>
            (stdLibPatches.patchSymbol(sym), path)
          }
        else symbols

      patched.foreach { case (sym, path) => recordToplevel(sym, path) }
    }
  }

  /**
   * Indexes the non-trivial toplevel symbols of a single source file.
   *
   * @param source the file to index.
   * @param sourceDirectory directory used to relativize the URI of `source`, if any.
   * @param fromSourceJar the source jar that contains `source`, if any; symbols
   *                      from the Scala 3 library jar are patched before being recorded.
   * @return the toplevel symbols found in `source`, before any patching.
   */
  def addSourceFile(
      source: AbsolutePath,
      sourceDirectory: Option[AbsolutePath],
      fromSourceJar: Option[AbsolutePath] = None
  ): List[String] = {
    val uri = source.toIdeallyRelativeURI(sourceDirectory)
    val symbols = indexSource(source, uri, dialect)

    val patched =
      fromSourceJar match {
        case Some(jar) if stdLibPatches.isScala3Library(jar) =>
          symbols.map(stdLibPatches.patchSymbol)
        case _ => symbols
      }

    patched.foreach(symbol => recordToplevel(symbol, source))
    // NOTE(review): the unpatched symbols are returned on purpose as far as this
    // file shows: `addIndexedSourceJar` applies the patch itself to the symbols it
    // is given -- confirm against callers before changing this.
    symbols
  }

  // Records that `symbol` is defined in `path`, keeping previously known files.
  private def recordToplevel(symbol: String, path: AbsolutePath): Unit = {
    val known = toplevels.getOrElse(symbol, Set.empty)
    toplevels(symbol) = known + path
  }

  // Reads `source` as UTF-8 and returns its toplevel symbols. Trivial toplevel
  // symbols are dropped, except for Ammonite scripts where all of them are kept.
  private def indexSource(
      source: AbsolutePath,
      uri: String,
      dialect: Dialect
  ): List[String] = {
    val text = FileIO.slurp(source, StandardCharsets.UTF_8)
    val input = Input.VirtualFile(uri, text)
    val sourceToplevels = mtags.toplevels(input, dialect)
    if (source.isAmmoniteScript)
      sourceToplevels
    else
      sourceToplevels.filter(sym => !isTrivialToplevelSymbol(uri, sym))
  }

  // Returns true if symbol is com/foo/Bar# and path is /com/foo/Bar.scala
  // Such symbols are "trivial" because their definition location can be computed
  // on the fly.
  // The comparison drops the first character and the last ".scala".length
  // characters of `path`, and the last character of `symbol`.
  private def isTrivialToplevelSymbol(path: String, symbol: String): Boolean = {
    val pathEnd = path.length - ".scala".length
    // `subSequence` throws IndexOutOfBoundsException when the end index is
    // smaller than the start index, so inputs that are too short to carry the
    // expected prefix/suffix are reported as non-trivial instead of crashing.
    if (pathEnd < 1 || symbol.isEmpty) false
    else {
      val pathBuffer =
        CharBuffer.wrap(path).subSequence(1, pathEnd)
      val symbolBuffer =
        CharBuffer.wrap(symbol).subSequence(0, symbol.length - 1)
      pathBuffer.equals(symbolBuffer)
    }
  }

  /**
   * Records an already computed toplevel symbol.
   *
   * The symbol is recorded when `source` is an Ammonite script or when the
   * symbol is not trivial with respect to `path`.
   *
   * @param path the path the symbol is compared against to detect trivial symbols.
   * @param source the file that defines `toplevel`.
   * @param toplevel the toplevel symbol.
   */
  def addToplevelSymbol(
      path: String,
      source: AbsolutePath,
      toplevel: String
  ): Unit = {
    if (source.isAmmoniteScript || !isTrivialToplevelSymbol(path, toplevel)) {
      recordToplevel(toplevel, source)
    }
  }

  /** Returns the known definitions of `symbol`, indexing sources on demand. */
  def query(symbol: Symbol): List[SymbolDefinition] =
    query0(symbol, symbol)

  /**
   * Returns the files where symbol is defined, if any.
   *
   * Uses two strategies to recover from missing symbol definitions:
   * - try to enter the toplevel symbol definition, then lookup symbol again.
   * - if the symbol is synthetic, for example from a case class or macro annotation,
   *   fall back to related symbols from the enclosing class, see `DefinitionAlternatives`.
   *
   * @param querySymbol The original symbol that was queried by the user.
   * @param symbol The symbol that is currently looked up; it differs from
   *               `querySymbol` when an alternative symbol is being tried.
   * @return one definition per known location, or an empty list.
   */
  private def query0(
      querySymbol: Symbol,
      symbol: Symbol
  ): List[SymbolDefinition] = {
    if (!definitions.contains(symbol.value)) {
      // Fallback 1: enter the toplevel symbol definition
      val toplevel = symbol.toplevel
      toplevels.get(toplevel.value) match {
        case Some(files) =>
          files.foreach(addMtagsSourceFile)
        case _ =>
          loadFromSourceJars(trivialPaths(toplevel))
            .orElse(loadFromSourceJars(modulePaths(toplevel)))
            .foreach(_.foreach(addMtagsSourceFile))
      }
    }
    definitions.get(symbol.value) match {
      case Some(locations) =>
        locations.map { location =>
          SymbolDefinition(
            querySymbol = querySymbol,
            definitionSymbol = symbol,
            path = location.path,
            dialect = dialect,
            range = location.range
          )
        }.toList
      case None =>
        // Fallback 2: guess related symbols from the enclosing class.
        DefinitionAlternatives(symbol)
          .flatMap(alternative => query0(querySymbol, alternative))
    }
  }

  // similar to addSourceFile except it indexes all global symbols instead of
  // only non-trivial toplevel symbols. Files whose extension is not
  // scala/java/sc are ignored.
  private def addMtagsSourceFile(file: AbsolutePath): Unit = {
    val docs: s.TextDocuments = PathIO.extension(file.toNIO) match {
      case "scala" | "java" | "sc" =>
        val language = file.toLanguage
        val toIndexSource0 = toIndexSource(file)
        val input = toIndexSource0.toInput
        val document =
          stdLibPatches.patchDocument(
            file,
            mtags.index(language, input, dialect)
          )
        s.TextDocuments(List(document))
      case _ =>
        s.TextDocuments(Nil)
    }
    if (docs.documents.nonEmpty) {
      addTextDocuments(file, docs)
    }
  }

  // Records all global symbol definitions.
  private def addTextDocuments(
      file: AbsolutePath,
      docs: s.TextDocuments
  ): Unit = {
    docs.documents.foreach { document =>
      document.occurrences.foreach { occ =>
        // we only care about global symbol definitions.
        if (occ.symbol.isGlobal && occ.role.isDefinition) {
          val acc = definitions.getOrElse(occ.symbol, Set.empty)
          definitions.put(occ.symbol, acc + SymbolLocation(file, occ.range))
        }
      }
    }
  }

  // Returns the files of the first path that resolves to at least one file.
  // Paths are tried lazily, in order.
  private def loadFromSourceJars(
      paths: List[String]
  ): Option[List[AbsolutePath]] =
    paths.iterator.map(path => sourceJars.loadAll(path)).find(_.nonEmpty)

  // Returns relative file paths for trivial toplevel symbols, example:
  // Input:  scala/collection/immutable/List#
  // Output: scala/collection/immutable/List.scala
  //         scala/collection/immutable/List.java
  private def trivialPaths(toplevel: Symbol): List[String] = {
    val noExtension = toplevel.value.stripSuffix(".").stripSuffix("#")
    List(
      noExtension + ".scala",
      noExtension + ".java"
    )
  }

  // Returns candidate paths prefixed with the name of the module that declares
  // the class of `toplevel`. Only attempted on Java 9+, otherwise Nil.
  private def modulePaths(toplevel: Symbol): List[String] = {
    if (Properties.isJavaAtLeast("9")) {
      val noExtension = toplevel.value.stripSuffix(".").stripSuffix("#")
      val javaSymbol = noExtension.replace("/", ".")
      for {
        cls <- sourceJars.loadClass(javaSymbol).toList
        // note(@tgodzik) Modules are only available in Java 9+, so we need to invoke this reflectively
        module <- Option(
          cls.getClass().getMethod("getModule").invoke(cls)
        ).toList
        moduleName <- Option(
          module.getClass().getMethod("getName").invoke(module)
        ).toList
        file <- List(
          s"$moduleName/$noExtension.java",
          s"$moduleName/$noExtension.scala"
        )
      } yield file
    } else {
      Nil
    }
  }
}
297
object SymbolIndexBucket {

  /**
   * Creates a bucket with no indexed symbols and a fresh `ClasspathLoader`.
   *
   * @param dialect the dialect the bucket's sources are parsed with.
   * @param mtags the indexer used by the bucket.
   * @param toIndexSource mapping handed over to the bucket unchanged.
   */
  def empty(
      dialect: Dialect,
      mtags: Mtags,
      toIndexSource: AbsolutePath => AbsolutePath
  ): SymbolIndexBucket =
    new SymbolIndexBucket(
      toplevels = TrieMap.empty,
      definitions = TrieMap.empty,
      sourceJars = new ClasspathLoader(),
      toIndexSource = toIndexSource,
      mtags = mtags,
      dialect = dialect
    )

  /**
   * Scala 3 has a specific package that adds / replaces some symbols in scala.Predef + scala.language
   * https://github.com/lampepfl/dotty/blob/main/library/src/scala/runtime/stdLibPatches/
   * We need to do the same to correctly provide location for symbols obtained from semanticdb.
   */
  object stdLibPatches {

    /** Package whose occurrences in symbols get rewritten to `scala`. */
    val packageName = "scala/runtime/stdLibPatches"

    /** True when the file name of `jar` starts with `scala3-library_3`. */
    def isScala3Library(jar: AbsolutePath): Boolean =
      jar.filename.startsWith("scala3-library_3")

    /**
     * True when `file` lives in a Scala 3 library jar and its parent
     * directory is named `stdLibPatches`.
     */
    def isScala3LibraryPatchSource(file: AbsolutePath): Boolean = {
      // The parent directory is only inspected once the jar check succeeded.
      if (file.jarPath.exists(jar => isScala3Library(jar)))
        file.parent.filename == "stdLibPatches"
      else
        false
    }

    /** Replaces every occurrence of [[packageName]] in `sym` with `scala`. */
    def patchSymbol(sym: String): String =
      sym.replace(packageName, "scala")

    /**
     * Rewrites the symbols of all occurrences in `doc` when `file` is a
     * Scala 3 library patch source; any other document is returned as is.
     */
    def patchDocument(
        file: AbsolutePath,
        doc: s.TextDocument
    ): s.TextDocument = {
      if (!isScala3LibraryPatchSource(file)) doc
      else {
        val patchedOccurrences = doc.occurrences.map { occurrence =>
          occurrence.copy(symbol = patchSymbol(occurrence.symbol))
        }
        doc.copy(occurrences = patchedOccurrences)
      }
    }

  }
}