Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions internal/ast/symbol.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ type Symbol struct {

type SymbolTable map[string]*Symbol

// GetOrInit returns the symbol table, or initializes it if it's nil.
// This will modify whatever holds the SymbolTable, so is not safe for concurrent use.
func (s *SymbolTable) GetOrInit() SymbolTable {
if *s == nil {
*s = make(SymbolTable)
}
return *s
}

const InternalSymbolNamePrefix = "\xFE" // Invalid UTF8 sequence, will never occur as IdentifierName

const (
Expand Down
19 changes: 0 additions & 19 deletions internal/ast/utilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,6 @@ func GetSymbolId(symbol *Symbol) SymbolId {
return SymbolId(id)
}

func GetSymbolTable(data *SymbolTable) SymbolTable {
if *data == nil {
*data = make(SymbolTable)
}
return *data
}

func GetMembers(symbol *Symbol) SymbolTable {
return GetSymbolTable(&symbol.Members)
}

func GetExports(symbol *Symbol) SymbolTable {
return GetSymbolTable(&symbol.Exports)
}

func GetLocals(container *Node) SymbolTable {
return GetSymbolTable(&container.LocalsContainerData().Locals)
}

// Determines if a node is missing (either `nil` or empty)
func NodeIsMissing(node *Node) bool {
return node == nil || node.Loc.Pos() == node.Loc.End() && node.Loc.Pos() >= 0 && node.Kind != KindEndOfFile
Expand Down
58 changes: 35 additions & 23 deletions internal/binder/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,25 @@ func GetSymbolNameForPrivateIdentifier(containingClassSymbol *ast.Symbol, descri
return ast.InternalSymbolNamePrefix + "#" + strconv.Itoa(int(ast.GetSymbolId(containingClassSymbol))) + "@" + description
}

func getMembers(symbol *ast.Symbol) ast.SymbolTable {
return symbol.Members.GetOrInit()
}

func getExports(symbol *ast.Symbol) ast.SymbolTable {
return symbol.Exports.GetOrInit()
}

func getLocals(container *ast.Node) ast.SymbolTable {
return container.LocalsContainerData().Locals.GetOrInit()
}

func (b *Binder) declareModuleMember(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol {
hasExportModifier := ast.GetCombinedModifierFlags(node)&ast.ModifierFlagsExport != 0
if symbolFlags&ast.SymbolFlagsAlias != 0 {
if node.Kind == ast.KindExportSpecifier || (node.Kind == ast.KindImportEqualsDeclaration && hasExportModifier) {
return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
}
return b.declareSymbol(ast.GetLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
return b.declareSymbol(getLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
}
// Exported module members are given 2 symbols: A local symbol that is classified with an ExportValue flag,
// and an associated export symbol with all the correct flags set on it. There are 2 main reasons:
Expand All @@ -378,33 +390,33 @@ func (b *Binder) declareModuleMember(node *ast.Node, symbolFlags ast.SymbolFlags
// and should never be merged directly with other augmentation, and the latter case would be possible if automatic merge is allowed.
if !ast.IsAmbientModule(node) && (hasExportModifier || b.container.Flags&ast.NodeFlagsExportContext != 0) {
if !ast.IsLocalsContainer(b.container) || (ast.HasSyntacticModifier(node, ast.ModifierFlagsDefault) && b.getDeclarationName(node) == ast.InternalSymbolNameMissing) {
return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
// No local symbol for an unnamed default!
}
exportKind := ast.SymbolFlagsNone
if symbolFlags&ast.SymbolFlagsValue != 0 {
exportKind = ast.SymbolFlagsExportValue
}
local := b.declareSymbol(ast.GetLocals(b.container), nil /*parent*/, node, exportKind, symbolExcludes)
local.ExportSymbol = b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
local := b.declareSymbol(getLocals(b.container), nil /*parent*/, node, exportKind, symbolExcludes)
local.ExportSymbol = b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
node.ExportableData().LocalSymbol = local
return local
}
return b.declareSymbol(ast.GetLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
return b.declareSymbol(getLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
}

func (b *Binder) declareClassMember(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol {
if ast.IsStatic(node) {
return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
}
return b.declareSymbol(ast.GetMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
}

func (b *Binder) declareSourceFileMember(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol {
if ast.IsExternalModule(b.file) {
return b.declareModuleMember(node, symbolFlags, symbolExcludes)
}
return b.declareSymbol(ast.GetLocals(b.file.AsNode()), nil /*parent*/, node, symbolFlags, symbolExcludes)
return b.declareSymbol(getLocals(b.file.AsNode()), nil /*parent*/, node, symbolFlags, symbolExcludes)
}

func (b *Binder) declareSymbolAndAddToSymbolTable(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol {
Expand All @@ -416,14 +428,14 @@ func (b *Binder) declareSymbolAndAddToSymbolTable(node *ast.Node, symbolFlags as
case ast.KindClassExpression, ast.KindClassDeclaration:
return b.declareClassMember(node, symbolFlags, symbolExcludes)
case ast.KindEnumDeclaration:
return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
case ast.KindTypeLiteral, ast.KindJSDocTypeLiteral, ast.KindObjectLiteralExpression, ast.KindInterfaceDeclaration, ast.KindJsxAttributes:
return b.declareSymbol(ast.GetMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
return b.declareSymbol(getMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes)
case ast.KindFunctionType, ast.KindConstructorType, ast.KindCallSignature, ast.KindConstructSignature, ast.KindJSDocSignature,
ast.KindIndexSignature, ast.KindMethodDeclaration, ast.KindMethodSignature, ast.KindConstructor, ast.KindGetAccessor,
ast.KindSetAccessor, ast.KindFunctionDeclaration, ast.KindFunctionExpression, ast.KindArrowFunction,
ast.KindClassStaticBlockDeclaration, ast.KindTypeAliasDeclaration, ast.KindMappedType:
return b.declareSymbol(ast.GetLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
return b.declareSymbol(getLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes)
}
panic("Unhandled case in declareSymbolAndAddToSymbolTable")
}
Expand Down Expand Up @@ -727,7 +739,7 @@ func (b *Binder) bindSourceFileIfExternalModule() {
b.bindSourceFileAsExternalModule()
// Create symbol equivalent for the module.exports = {}
originalSymbol := b.file.Symbol
b.declareSymbol(ast.GetSymbolTable(&b.file.Symbol.Exports), b.file.Symbol, b.file.AsNode(), ast.SymbolFlagsProperty, ast.SymbolFlagsAll)
b.declareSymbol(b.file.Symbol.Exports.GetOrInit(), b.file.Symbol, b.file.AsNode(), ast.SymbolFlagsProperty, ast.SymbolFlagsAll)
b.file.Symbol = originalSymbol
}
}
Expand Down Expand Up @@ -788,7 +800,7 @@ func (b *Binder) bindNamespaceExportDeclaration(node *ast.Node) {
case !node.Parent.AsSourceFile().IsDeclarationFile:
b.errorOnNode(node, diagnostics.Global_module_exports_may_only_appear_in_declaration_files)
default:
b.declareSymbol(ast.GetSymbolTable(&b.file.Symbol.GlobalExports), b.file.Symbol, node, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes)
b.declareSymbol(b.file.Symbol.GlobalExports.GetOrInit(), b.file.Symbol, node, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes)
}
}

Expand All @@ -805,12 +817,12 @@ func (b *Binder) bindExportDeclaration(node *ast.Node) {
b.bindAnonymousDeclaration(node, ast.SymbolFlagsExportStar, b.getDeclarationName(node))
} else if decl.ExportClause == nil {
// All export * declarations are collected in an __export symbol
b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, ast.SymbolFlagsExportStar, ast.SymbolFlagsNone)
b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, ast.SymbolFlagsExportStar, ast.SymbolFlagsNone)
} else if ast.IsNamespaceExport(decl.ExportClause) {
// declareSymbol walks up parents to find name text, parent _must_ be set
// but won't be set by the normal binder walk until `bindChildren` later on.
setParent(decl.ExportClause, node)
b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), decl.ExportClause, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes)
b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), decl.ExportClause, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes)
}
}

Expand All @@ -825,7 +837,7 @@ func (b *Binder) bindExportAssignment(node *ast.Node) {
}
// If there is an `export default x;` alias declaration, can't `export default` anything else.
// (In contrast, you can still have `export default function f() {}` and `export default interface I {}`.)
symbol := b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, flags, ast.SymbolFlagsAll)
symbol := b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, flags, ast.SymbolFlagsAll)
if node.AsExportAssignment().IsExportEquals {
// Will be an error later, since the module already has other exports. Just make sure this has a valueDeclaration set.
SetValueDeclaration(symbol, node)
Expand Down Expand Up @@ -904,12 +916,12 @@ func (b *Binder) bindClassLikeDeclaration(node *ast.Node) {
// module might have an exported variable called 'prototype'. We can't allow that as
// that would clash with the built-in 'prototype' for the class.
prototypeSymbol := b.newSymbol(ast.SymbolFlagsProperty|ast.SymbolFlagsPrototype, "prototype")
symbolExport := ast.GetExports(symbol)[prototypeSymbol.Name]
symbolExport := getExports(symbol)[prototypeSymbol.Name]
if symbolExport != nil {
setParent(name, node)
b.errorOnNode(symbolExport.Declarations[0], diagnostics.Duplicate_identifier_0, ast.SymbolName(prototypeSymbol))
}
ast.GetExports(symbol)[prototypeSymbol.Name] = prototypeSymbol
getExports(symbol)[prototypeSymbol.Name] = prototypeSymbol
prototypeSymbol.Parent = symbol
}

Expand Down Expand Up @@ -978,7 +990,7 @@ func (b *Binder) bindFunctionPropertyAssignment(node *ast.Node) {
b.bindAnonymousDeclaration(node, ast.SymbolFlagsProperty|ast.SymbolFlagsAssignment, ast.InternalSymbolNameComputed)
addLateBoundAssignmentDeclarationToSymbol(node, funcSymbol)
} else {
b.declareSymbol(ast.GetExports(funcSymbol), funcSymbol, node, ast.SymbolFlagsProperty|ast.SymbolFlagsAssignment, ast.SymbolFlagsPropertyExcludes)
b.declareSymbol(getExports(funcSymbol), funcSymbol, node, ast.SymbolFlagsProperty|ast.SymbolFlagsAssignment, ast.SymbolFlagsPropertyExcludes)
}
}
}
Expand Down Expand Up @@ -1039,7 +1051,7 @@ func (b *Binder) bindParameter(node *ast.Node) {
if ast.IsParameterPropertyDeclaration(node, node.Parent) {
classDeclaration := node.Parent.Parent
flags := ast.SymbolFlagsProperty | core.IfElse(decl.QuestionToken != nil, ast.SymbolFlagsOptional, ast.SymbolFlagsNone)
b.declareSymbol(ast.GetMembers(classDeclaration.Symbol()), classDeclaration.Symbol(), node, flags, ast.SymbolFlagsPropertyExcludes)
b.declareSymbol(getMembers(classDeclaration.Symbol()), classDeclaration.Symbol(), node, flags, ast.SymbolFlagsPropertyExcludes)
}
}

Expand Down Expand Up @@ -1086,7 +1098,7 @@ func (b *Binder) bindBlockScopedDeclaration(node *ast.Node, symbolFlags ast.Symb
}
fallthrough
default:
b.declareSymbol(ast.GetLocals(b.blockScopeContainer), nil /*parent*/, node, symbolFlags, symbolExcludes)
b.declareSymbol(getLocals(b.blockScopeContainer), nil /*parent*/, node, symbolFlags, symbolExcludes)
}
}

Expand All @@ -1105,7 +1117,7 @@ func (b *Binder) bindTypeParameter(node *ast.Node) {
if node.Parent.Kind == ast.KindInferType {
container := b.getInferTypeContainer(node.Parent)
if container != nil {
b.declareSymbol(ast.GetLocals(container), nil /*parent*/, node, ast.SymbolFlagsTypeParameter, ast.SymbolFlagsTypeParameterExcludes)
b.declareSymbol(getLocals(container), nil /*parent*/, node, ast.SymbolFlagsTypeParameter, ast.SymbolFlagsTypeParameterExcludes)
} else {
b.bindAnonymousDeclaration(node, ast.SymbolFlagsTypeParameter, b.getDeclarationName(node))
}
Expand Down
14 changes: 7 additions & 7 deletions internal/checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ func (c *Checker) mergeModuleAugmentation(moduleName *ast.Node) {
}) {
merged := c.mergeSymbol(moduleAugmentation.Symbol, mainModule, true /*unidirectional*/)
// moduleName will be a StringLiteral since this is not `declare global`.
ast.GetSymbolTable(&c.patternAmbientModuleAugmentations)[moduleName.Text()] = merged
c.patternAmbientModuleAugmentations.GetOrInit()[moduleName.Text()] = merged
} else {
if mainModule.Exports[ast.InternalSymbolNameExportStar] != nil && len(moduleAugmentation.Symbol.Exports) != 0 {
// We may need to merge the module augmentation's exports into the target symbols of the resolved exports
Expand Down Expand Up @@ -11159,10 +11159,10 @@ func (c *Checker) mergeSymbol(target *ast.Symbol, source *ast.Symbol, unidirecti
}
target.Declarations = append(target.Declarations, source.Declarations...)
if source.Members != nil {
c.mergeSymbolTable(ast.GetSymbolTable(&target.Members), source.Members, unidirectional, nil)
c.mergeSymbolTable(target.Members.GetOrInit(), source.Members, unidirectional, nil)
}
if source.Exports != nil {
c.mergeSymbolTable(ast.GetSymbolTable(&target.Exports), source.Exports, unidirectional, target)
c.mergeSymbolTable(target.Members.GetOrInit(), source.Exports, unidirectional, target)
}
if !unidirectional {
c.recordMergedSymbol(target, source)
Expand Down Expand Up @@ -12152,7 +12152,7 @@ func (c *Checker) getCommonJsExportEquals(exported *ast.Symbol, moduleSymbol *as
merged = c.cloneSymbol(exported)
}
merged.Flags |= ast.SymbolFlagsValueModule
mergedExports := ast.GetExports(merged)
mergedExports := merged.Exports.GetOrInit()
for name, s := range moduleSymbol.Exports {
if name != ast.InternalSymbolNameExportEquals {
if existing, ok := mergedExports[name]; ok {
Expand Down Expand Up @@ -17593,9 +17593,9 @@ func (c *Checker) getPropertyOfUnionOrIntersectionType(t *Type, name string, ski
func (c *Checker) getUnionOrIntersectionProperty(t *Type, name string, skipObjectFunctionPropertyAugment bool) *ast.Symbol {
var cache ast.SymbolTable
if skipObjectFunctionPropertyAugment {
cache = ast.GetSymbolTable(&t.AsUnionOrIntersectionType().propertyCacheWithoutFunctionPropertyAugment)
cache = t.AsUnionOrIntersectionType().propertyCacheWithoutFunctionPropertyAugment.GetOrInit()
} else {
cache = ast.GetSymbolTable(&t.AsUnionOrIntersectionType().propertyCache)
cache = t.AsUnionOrIntersectionType().propertyCache.GetOrInit()
}
if prop := cache[name]; prop != nil {
return prop
Expand All @@ -17605,7 +17605,7 @@ func (c *Checker) getUnionOrIntersectionProperty(t *Type, name string, skipObjec
cache[name] = prop
// Propagate an entry from the non-augmented cache to the augmented cache unless the property is partial.
if skipObjectFunctionPropertyAugment && prop.CheckFlags&ast.CheckFlagsPartial == 0 {
augmentedCache := ast.GetSymbolTable(&t.AsUnionOrIntersectionType().propertyCache)
augmentedCache := t.AsUnionOrIntersectionType().propertyCache.GetOrInit()
if augmentedCache[name] == nil {
augmentedCache[name] = prop
}
Expand Down
Loading