diff --git a/go.mod b/go.mod index 067f388aa4..6be49ef853 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874 github.com/google/go-cmp v0.7.0 github.com/pkg/diff v0.0.0-20241224192749-4e6772a4315c + golang.org/x/sync v0.11.0 golang.org/x/sys v0.31.0 gotest.tools/v3 v3.5.2 ) @@ -14,7 +15,6 @@ require ( require ( github.com/matryer/moq v0.5.3 // indirect golang.org/x/mod v0.23.0 // indirect - golang.org/x/sync v0.11.0 // indirect golang.org/x/tools v0.30.0 // indirect ) diff --git a/internal/api/api.go b/internal/api/api.go index 42191d6915..6f8e16d6bc 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "errors" "fmt" @@ -129,7 +130,7 @@ func (api *API) IsWatchEnabled() bool { return false } -func (api *API) HandleRequest(id int, method string, payload []byte) ([]byte, error) { +func (api *API) HandleRequest(ctx context.Context, method string, payload []byte) ([]byte, error) { params, err := unmarshalPayload(method, payload) if err != nil { return nil, err @@ -155,27 +156,27 @@ func (api *API) HandleRequest(id int, method string, payload []byte) ([]byte, er return encodeJSON(api.LoadProject(params.(*LoadProjectParams).ConfigFileName)) case MethodGetSymbolAtPosition: params := params.(*GetSymbolAtPositionParams) - return encodeJSON(api.GetSymbolAtPosition(params.Project, params.FileName, int(params.Position))) + return encodeJSON(api.GetSymbolAtPosition(ctx, params.Project, params.FileName, int(params.Position))) case MethodGetSymbolsAtPositions: params := params.(*GetSymbolsAtPositionsParams) return encodeJSON(core.TryMap(params.Positions, func(position uint32) (any, error) { - return api.GetSymbolAtPosition(params.Project, params.FileName, int(position)) + return api.GetSymbolAtPosition(ctx, params.Project, params.FileName, int(position)) })) case MethodGetSymbolAtLocation: params := params.(*GetSymbolAtLocationParams) - return encodeJSON(api.GetSymbolAtLocation(params.Project, params.Location)) + return encodeJSON(api.GetSymbolAtLocation(ctx, params.Project, params.Location)) case MethodGetSymbolsAtLocations: params := params.(*GetSymbolsAtLocationsParams) return encodeJSON(core.TryMap(params.Locations, func(location Handle[ast.Node]) (any, error) { - return api.GetSymbolAtLocation(params.Project, location) + return api.GetSymbolAtLocation(ctx, params.Project, location) })) case MethodGetTypeOfSymbol: params := params.(*GetTypeOfSymbolParams) - return encodeJSON(api.GetTypeOfSymbol(params.Project, params.Symbol)) + return encodeJSON(api.GetTypeOfSymbol(ctx, params.Project, params.Symbol)) case MethodGetTypesOfSymbols: params := params.(*GetTypesOfSymbolsParams) return encodeJSON(core.TryMap(params.Symbols, func(symbol Handle[ast.Symbol]) (any, error) { - return api.GetTypeOfSymbol(params.Project, symbol) + return api.GetTypeOfSymbol(ctx, params.Project, symbol) })) default: return nil, fmt.Errorf("unhandled API method %q", method) @@ -223,12 +224,14 @@ func (api *API) LoadProject(configFileName string) (*ProjectResponse, error) { return data, nil } -func (api *API) GetSymbolAtPosition(projectId Handle[project.Project], fileName string, position int) (*SymbolResponse, error) { +func (api *API) GetSymbolAtPosition(ctx context.Context, projectId Handle[project.Project], fileName string, position int) (*SymbolResponse, error) { project, ok := api.projects[projectId] if !ok { return nil, errors.New("project not found") } - symbol, err := project.LanguageService().GetSymbolAtPosition(fileName, position) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + symbol, err := languageService.GetSymbolAtPosition(ctx, fileName, position) if err != nil || symbol == nil { return nil, err } @@ -239,7 +242,7 @@ func (api *API) GetSymbolAtPosition(projectId Handle[project.Project], fileName return data, nil } -func (api *API) GetSymbolAtLocation(projectId Handle[project.Project], location Handle[ast.Node]) (*SymbolResponse, error) { +func (api *API) GetSymbolAtLocation(ctx context.Context, projectId Handle[project.Project], location Handle[ast.Node]) (*SymbolResponse, error) { project, ok := api.projects[projectId] if !ok { return nil, errors.New("project not found") @@ -262,7 +265,9 @@ func (api *API) GetSymbolAtLocation(projectId Handle[project.Project], location if node == nil { return nil, fmt.Errorf("node of kind %s not found at position %d in file %q", kind.String(), pos, sourceFile.FileName()) } - symbol := project.LanguageService().GetSymbolAtLocation(node) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + symbol := languageService.GetSymbolAtLocation(ctx, node) if symbol == nil { return nil, nil } @@ -273,7 +278,7 @@ func (api *API) GetSymbolAtLocation(projectId Handle[project.Project], location return data, nil } -func (api *API) GetTypeOfSymbol(projectId Handle[project.Project], symbolHandle Handle[ast.Symbol]) (*TypeResponse, error) { +func (api *API) GetTypeOfSymbol(ctx context.Context, projectId Handle[project.Project], symbolHandle Handle[ast.Symbol]) (*TypeResponse, error) { project, ok := api.projects[projectId] if !ok { return nil, errors.New("project not found") @@ -284,7 +289,9 @@ func (api *API) GetTypeOfSymbol(projectId Handle[project.Project], symbolHandle if !ok { return nil, fmt.Errorf("symbol %q not found", symbolHandle) } - t := project.LanguageService().GetTypeOfSymbol(symbol) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + t := languageService.GetTypeOfSymbol(ctx, symbol) if t == nil { return nil, nil } diff --git a/internal/api/server.go b/internal/api/server.go index 2b60dd855a..061e5d7bae 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -2,13 +2,16 @@ package api import ( "bufio" + "context" "encoding/binary" "encoding/json" "fmt" "io" + "strconv" "sync" "github.com/microsoft/typescript-go/internal/bundled" + "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/project" "github.com/microsoft/typescript-go/internal/vfs" "github.com/microsoft/typescript-go/internal/vfs/osvfs" @@ -254,7 +257,7 @@ func (s *Server) handleRequest(method string, payload []byte) ([]byte, error) { case "echo": return payload, nil default: - return s.api.HandleRequest(s.requestId, method, payload) + return s.api.HandleRequest(core.WithRequestID(context.Background(), strconv.Itoa(s.requestId)), method, payload) } } diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index 3922e6ffd3..3458c8687c 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -39,7 +39,8 @@ foo.bar;` } p := compiler.NewProgram(opts) p.BindSourceFiles() - c := p.GetTypeChecker() + c, done := p.GetTypeChecker(t.Context()) + defer done() file := p.GetSourceFile("/foo.ts") interfaceId := file.Statements.Nodes[0].Name() varId := file.Statements.Nodes[1].AsVariableStatement().DeclarationList.AsVariableDeclarationList().Declarations.Nodes[0].Name() diff --git a/internal/checker/exports.go b/internal/checker/exports.go index 4bf67fbba0..d04b7a97cc 100644 --- a/internal/checker/exports.go +++ b/internal/checker/exports.go @@ -48,3 +48,7 @@ func (c *Checker) GetTypeOfPropertyOfContextualType(t *Type, name string) *Type func GetDeclarationModifierFlagsFromSymbol(s *ast.Symbol) ast.ModifierFlags { return getDeclarationModifierFlagsFromSymbol(s) } + +func (c *Checker) WasCanceled() bool { + return c.wasCanceled +} diff --git a/internal/compiler/checkerpool.go b/internal/compiler/checkerpool.go new file mode 100644 index 0000000000..17949cd568 --- /dev/null +++ b/internal/compiler/checkerpool.go @@ -0,0 +1,90 @@ +package compiler + +import ( + "context" + "iter" + "slices" + "sync" + + "github.com/microsoft/typescript-go/internal/ast" + "github.com/microsoft/typescript-go/internal/checker" + "github.com/microsoft/typescript-go/internal/core" +) + +type CheckerPool interface { + GetChecker(ctx context.Context) (*checker.Checker, func()) + GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) + GetAllCheckers(ctx context.Context) ([]*checker.Checker, func()) + Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] +} + +type checkerPool struct { + checkerCount int + program *Program + + createCheckersOnce sync.Once + checkers []*checker.Checker + fileAssociations map[*ast.SourceFile]*checker.Checker +} + +var _ CheckerPool = (*checkerPool)(nil) + +func newCheckerPool(checkerCount int, program *Program) *checkerPool { + pool := &checkerPool{ + program: program, + checkerCount: checkerCount, + checkers: make([]*checker.Checker, checkerCount), + } + + return pool +} + +func (p *checkerPool) GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) { + p.createCheckers() + checker := p.fileAssociations[file] + return checker, noop +} + +func (p *checkerPool) GetChecker(ctx context.Context) (*checker.Checker, func()) { + p.createCheckers() + checker := p.checkers[0] + return checker, noop +} + +func (p *checkerPool) createCheckers() { + p.createCheckersOnce.Do(func() { + wg := core.NewWorkGroup(p.program.singleThreaded()) + for i := range p.checkerCount { + wg.Queue(func() { + p.checkers[i] = checker.NewChecker(p.program) + }) + } + + wg.RunAndWait() + + p.fileAssociations = make(map[*ast.SourceFile]*checker.Checker, len(p.program.files)) + for i, file := range p.program.files { + p.fileAssociations[file] = p.checkers[i%p.checkerCount] + } + }) +} + +func (p *checkerPool) GetAllCheckers(ctx context.Context) ([]*checker.Checker, func()) { + p.createCheckers() + return p.checkers, noop +} + +func (p *checkerPool) Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] { + checkerIndex := slices.Index(p.checkers, checker) + return func(yield func(*ast.SourceFile) bool) { + for i, file := range p.program.files { + if i%p.checkerCount == checkerIndex { + if !yield(file) { + return + } + } + } + } +} + +func noop() {} diff --git a/internal/compiler/emitHost.go b/internal/compiler/emitHost.go index 0f73ed940f..7afff9d8ce 100644 --- a/internal/compiler/emitHost.go +++ b/internal/compiler/emitHost.go @@ -1,6 +1,8 @@ package compiler import ( + "context" + "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/printer" @@ -32,7 +34,11 @@ func (host *emitHost) WriteFile(fileName string, text string, writeByteOrderMark } func (host *emitHost) GetEmitResolver(file *ast.SourceFile, skipDiagnostics bool) printer.EmitResolver { - checker := host.program.GetTypeCheckerForFile(file) + // The context and done function don't matter in tsc, currently the only caller of this function. + // But if this ever gets used by LSP code, we'll need to thread the context properly and pass the + // done function to the caller to ensure resources are cleaned up at the end of the request. + checker, done := host.program.GetTypeCheckerForFile(context.TODO(), file) + defer done() return checker.GetEmitResolver(file, skipDiagnostics) } diff --git a/internal/compiler/program.go b/internal/compiler/program.go index cb1a6c5e3e..d05b3fffa8 100644 --- a/internal/compiler/program.go +++ b/internal/compiler/program.go @@ -27,6 +27,7 @@ type ProgramOptions struct { SingleThreaded core.Tristate ProjectReference []core.ProjectReference ConfigFileParsingDiagnostics []*ast.Diagnostic + CreateCheckerPool func(*Program) CheckerPool } type Program struct { @@ -35,9 +36,7 @@ type Program struct { compilerOptions *core.CompilerOptions configFileName string nodeModules map[string]*ast.SourceFile - checkers []*checker.Checker - checkersOnce sync.Once - checkersByFile map[*ast.SourceFile]*checker.Checker + checkerPool CheckerPool currentDirectory string configFileParsingDiagnostics []*ast.Diagnostic @@ -79,6 +78,11 @@ func NewProgram(options ProgramOptions) *Program { if p.compilerOptions == nil { p.compilerOptions = &core.CompilerOptions{} } + if p.programOptions.CreateCheckerPool != nil { + p.checkerPool = p.programOptions.CreateCheckerPool(p) + } else { + p.checkerPool = newCheckerPool(core.IfElse(p.singleThreaded(), 1, 4), p) + } // p.maxNodeModuleJsDepth = p.options.MaxNodeModuleJsDepth @@ -212,56 +216,34 @@ func (p *Program) BindSourceFiles() { } func (p *Program) CheckSourceFiles(ctx context.Context) { - p.createCheckers() wg := core.NewWorkGroup(p.singleThreaded()) - for index, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(ctx) + defer done() + for _, checker := range checkers { wg.Queue(func() { - for i := index; i < len(p.files); i += len(p.checkers) { - checker.CheckSourceFile(ctx, p.files[i]) + for file := range p.checkerPool.Files(checker) { + checker.CheckSourceFile(ctx, file) } }) } wg.RunAndWait() } -func (p *Program) createCheckers() { - p.checkersOnce.Do(func() { - p.checkers = make([]*checker.Checker, core.IfElse(p.singleThreaded(), 1, 4)) - wg := core.NewWorkGroup(p.singleThreaded()) - for i := range p.checkers { - wg.Queue(func() { - p.checkers[i] = checker.NewChecker(p) - }) - } - wg.RunAndWait() - p.checkersByFile = make(map[*ast.SourceFile]*checker.Checker) - for i, file := range p.files { - p.checkersByFile[file] = p.checkers[i%len(p.checkers)] - } - }) -} - // Return the type checker associated with the program. -func (p *Program) GetTypeChecker() *checker.Checker { - p.createCheckers() - // Just use the first (and possibly only) checker for checker requests. Such requests are likely - // to obtain types through multiple API calls and we want to ensure that those types are created - // by the same checker so they can interoperate. - return p.checkers[0] +func (p *Program) GetTypeChecker(ctx context.Context) (*checker.Checker, func()) { + return p.checkerPool.GetChecker(ctx) } -func (p *Program) GetTypeCheckers() []*checker.Checker { - p.createCheckers() - return p.checkers +func (p *Program) GetTypeCheckers(ctx context.Context) ([]*checker.Checker, func()) { + return p.checkerPool.GetAllCheckers(ctx) } // Return a checker for the given file. We may have multiple checkers in concurrent scenarios and this // method returns the checker that was tasked with checking the file. Note that it isn't possible to mix // types obtained from different checkers, so only non-type data (such as diagnostics or string // representations of types) should be obtained from checkers returned by this method. -func (p *Program) GetTypeCheckerForFile(file *ast.SourceFile) *checker.Checker { - p.createCheckers() - return p.checkersByFile[file] +func (p *Program) GetTypeCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) { + return p.checkerPool.GetCheckerForFile(ctx, file) } func (p *Program) GetResolvedModule(file *ast.SourceFile, moduleReference string) *ast.SourceFile { @@ -294,17 +276,19 @@ func (p *Program) GetSemanticDiagnostics(ctx context.Context, sourceFile *ast.So return p.getDiagnosticsHelper(ctx, sourceFile, true /*ensureBound*/, true /*ensureChecked*/, p.getSemanticDiagnosticsForFile) } -func (p *Program) GetGlobalDiagnostics() []*ast.Diagnostic { - p.createCheckers() +func (p *Program) GetGlobalDiagnostics(ctx context.Context) []*ast.Diagnostic { var globalDiagnostics []*ast.Diagnostic - for _, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(ctx) + defer done() + for _, checker := range checkers { globalDiagnostics = append(globalDiagnostics, checker.GetGlobalDiagnostics()...) } + return SortAndDeduplicateDiagnostics(globalDiagnostics) } -func (p *Program) GetOptionsDiagnostics() []*ast.Diagnostic { - return SortAndDeduplicateDiagnostics(append(p.GetGlobalDiagnostics(), p.getOptionsDiagnosticsOfConfigFile()...)) +func (p *Program) GetOptionsDiagnostics(ctx context.Context) []*ast.Diagnostic { + return SortAndDeduplicateDiagnostics(append(p.GetGlobalDiagnostics(ctx), p.getOptionsDiagnosticsOfConfigFile()...)) } func (p *Program) getOptionsDiagnosticsOfConfigFile() []*ast.Diagnostic { @@ -332,14 +316,20 @@ func (p *Program) getSemanticDiagnosticsForFile(ctx context.Context, sourceFile if checker.SkipTypeChecking(sourceFile, p.compilerOptions) { return nil } + var fileChecker *checker.Checker + var done func() if sourceFile != nil { - fileChecker = p.GetTypeCheckerForFile(sourceFile) + fileChecker, done = p.checkerPool.GetCheckerForFile(ctx, sourceFile) + defer done() } diags := slices.Clip(sourceFile.BindDiagnostics()) + checkers, closeCheckers := p.checkerPool.GetAllCheckers(ctx) + defer closeCheckers() + // Ask for diags from all checkers; checking one file may add diagnostics to other files. // These are deduplicated later. - for _, checker := range p.checkers { + for _, checker := range checkers { if sourceFile == nil || checker == fileChecker { diags = append(diags, checker.GetDiagnostics(ctx, sourceFile)...) } else { @@ -482,7 +472,9 @@ func (p *Program) SymbolCount() int { for _, file := range p.files { count += file.SymbolCount } - for _, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(context.Background()) + defer done() + for _, checker := range checkers { count += int(checker.SymbolCount) } return count @@ -490,7 +482,9 @@ func (p *Program) SymbolCount() int { func (p *Program) TypeCount() int { var count int - for _, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(context.Background()) + defer done() + for _, checker := range checkers { count += int(checker.TypeCount) } return count @@ -498,7 +492,9 @@ func (p *Program) TypeCount() int { func (p *Program) InstantiationCount() int { var count int - for _, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(context.Background()) + defer done() + for _, checker := range checkers { count += int(checker.TotalInstantiationCount) } return count diff --git a/internal/core/context.go b/internal/core/context.go new file mode 100644 index 0000000000..f755cdedba --- /dev/null +++ b/internal/core/context.go @@ -0,0 +1,21 @@ +package core + +import "context" + +type key int + +var requestIDKey key + +func WithRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, requestIDKey, id) +} + +func GetRequestID(ctx context.Context) string { + if ctx == nil { + return "" + } + if id, ok := ctx.Value(requestIDKey).(string); ok { + return id + } + return "" +} diff --git a/internal/execute/tsc.go b/internal/execute/tsc.go index 959817f882..e26528f36c 100644 --- a/internal/execute/tsc.go +++ b/internal/execute/tsc.go @@ -239,25 +239,26 @@ type compileAndEmitResult struct { func compileAndEmit(sys System, program *compiler.Program, reportDiagnostic diagnosticReporter) (result compileAndEmitResult) { // todo: check if third return needed after execute is fully implemented + ctx := context.Background() options := program.Options() allDiagnostics := program.GetConfigFileParsingDiagnostics() // todo: early exit logic and append diagnostics - diagnostics := program.GetSyntacticDiagnostics(context.Background(), nil) + diagnostics := program.GetSyntacticDiagnostics(ctx, nil) if len(diagnostics) == 0 { bindStart := time.Now() - _ = program.GetBindDiagnostics(context.Background(), nil) + _ = program.GetBindDiagnostics(ctx, nil) result.bindTime = time.Since(bindStart) - diagnostics = append(diagnostics, program.GetOptionsDiagnostics()...) + diagnostics = append(diagnostics, program.GetOptionsDiagnostics(ctx)...) if options.ListFilesOnly.IsFalse() { // program.GetBindDiagnostics(nil) - diagnostics = append(diagnostics, program.GetGlobalDiagnostics()...) + diagnostics = append(diagnostics, program.GetGlobalDiagnostics(ctx)...) } } if len(diagnostics) == 0 { checkStart := time.Now() - diagnostics = append(diagnostics, program.GetSemanticDiagnostics(context.Background(), nil)...) + diagnostics = append(diagnostics, program.GetSemanticDiagnostics(ctx, nil)...) result.checkTime = time.Since(checkStart) } // TODO: declaration diagnostics diff --git a/internal/ls/api.go b/internal/ls/api.go index 62fd233a6f..4393806e72 100644 --- a/internal/ls/api.go +++ b/internal/ls/api.go @@ -1,6 +1,7 @@ package ls import ( + "context" "errors" "fmt" @@ -14,7 +15,7 @@ var ( ErrNoTokenAtPosition = errors.New("no token found at position") ) -func (l *LanguageService) GetSymbolAtPosition(fileName string, position int) (*ast.Symbol, error) { +func (l *LanguageService) GetSymbolAtPosition(ctx context.Context, fileName string, position int) (*ast.Symbol, error) { program, file := l.tryGetProgramAndFile(fileName) if file == nil { return nil, fmt.Errorf("%w: %s", ErrNoSourceFile, fileName) @@ -23,17 +24,21 @@ func (l *LanguageService) GetSymbolAtPosition(fileName string, position int) (*a if node == nil { return nil, fmt.Errorf("%w: %s:%d", ErrNoTokenAtPosition, fileName, position) } - checker := program.GetTypeChecker() + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() return checker.GetSymbolAtLocation(node), nil } -func (l *LanguageService) GetSymbolAtLocation(node *ast.Node) *ast.Symbol { +func (l *LanguageService) GetSymbolAtLocation(ctx context.Context, node *ast.Node) *ast.Symbol { program := l.GetProgram() - checker := program.GetTypeChecker() + checker, done := program.GetTypeCheckerForFile(ctx, ast.GetSourceFileOfNode(node)) + defer done() return checker.GetSymbolAtLocation(node) } -func (l *LanguageService) GetTypeOfSymbol(symbol *ast.Symbol) *checker.Type { - checker := l.GetProgram().GetTypeChecker() +func (l *LanguageService) GetTypeOfSymbol(ctx context.Context, symbol *ast.Symbol) *checker.Type { + program := l.GetProgram() + checker, done := program.GetTypeChecker(ctx) + defer done() return checker.GetTypeOfSymbolAtLocation(symbol, nil) } diff --git a/internal/ls/completions.go b/internal/ls/completions.go index 262688bda4..257b4b2d87 100644 --- a/internal/ls/completions.go +++ b/internal/ls/completions.go @@ -1,6 +1,7 @@ package ls import ( + "context" "fmt" "maps" "slices" @@ -22,14 +23,23 @@ import ( ) func (l *LanguageService) ProvideCompletion( - fileName string, - position int, + ctx context.Context, + documentURI lsproto.DocumentUri, + position lsproto.Position, context *lsproto.CompletionContext, clientOptions *lsproto.CompletionClientCapabilities, preferences *UserPreferences, -) *lsproto.CompletionList { - program, file := l.getProgramAndFile(fileName) - return l.getCompletionsAtPosition(program, file, position, context, preferences, clientOptions) +) (*lsproto.CompletionList, error) { + program, file := l.getProgramAndFile(documentURI) + return l.getCompletionsAtPosition( + ctx, + program, + file, + int(l.converters.LineAndCharacterToPosition(file, position)), + context, + preferences, + clientOptions, + ), nil } // *completionDataData | *completionDataKeyword @@ -257,6 +267,7 @@ const ( ) func (l *LanguageService) getCompletionsAtPosition( + ctx context.Context, program *compiler.Program, file *ast.SourceFile, position int, @@ -287,7 +298,9 @@ func (l *LanguageService) getCompletionsAtPosition( // !!! label completions - data := getCompletionData(program, file, position, preferences) + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() + data := getCompletionData(program, checker, file, position, preferences) if data == nil { return nil } @@ -295,6 +308,7 @@ func (l *LanguageService) getCompletionsAtPosition( switch data := data.(type) { case *completionDataData: response := l.completionInfoFromData( + ctx, file, program, compilerOptions, @@ -313,8 +327,7 @@ func (l *LanguageService) getCompletionsAtPosition( } } -func getCompletionData(program *compiler.Program, file *ast.SourceFile, position int, preferences *UserPreferences) completionData { - typeChecker := program.GetTypeChecker() +func getCompletionData(program *compiler.Program, typeChecker *checker.Checker, file *ast.SourceFile, position int, preferences *UserPreferences) completionData { inCheckedFile := isCheckedFile(file, program.GetCompilerOptions()) currentToken := astnav.GetTokenAtPosition(file, position) @@ -1451,6 +1464,7 @@ func getDefaultCommitCharacters(isNewIdentifierLocation bool) []string { } func (l *LanguageService) completionInfoFromData( + ctx context.Context, file *ast.SourceFile, program *compiler.Program, compilerOptions *core.CompilerOptions, @@ -1488,6 +1502,7 @@ func (l *LanguageService) completionInfoFromData( } uniqueNames, sortedEntries := l.getCompletionEntriesFromSymbols( + ctx, data, nil, /*replacementToken*/ position, @@ -1548,6 +1563,7 @@ func (l *LanguageService) completionInfoFromData( } func (l *LanguageService) getCompletionEntriesFromSymbols( + ctx context.Context, data *completionDataData, replacementToken *ast.Node, position int, @@ -1560,7 +1576,8 @@ func (l *LanguageService) getCompletionEntriesFromSymbols( ) (uniqueNames core.Set[string], sortedEntries []*lsproto.CompletionItem) { closestSymbolDeclaration := getClosestSymbolDeclaration(data.contextToken, data.location) useSemicolons := probablyUsesSemicolons(file) - typeChecker := program.GetTypeChecker() + typeChecker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() isMemberCompletion := isMemberCompletionKind(data.completionKind) optionalReplacementSpan := getOptionalReplacementSpan(data.location, file) // Tracks unique names. @@ -1602,6 +1619,7 @@ func (l *LanguageService) getCompletionEntriesFromSymbols( sortText = originalSortText } entry := l.createCompletionItem( + ctx, symbol, sortText, replacementToken, @@ -1669,6 +1687,7 @@ func createCompletionItemForLiteral( } func (l *LanguageService) createCompletionItem( + ctx context.Context, symbol *ast.Symbol, sortText sortText, replacementToken *ast.Node, @@ -1694,7 +1713,8 @@ func (l *LanguageService) createCompletionItem( source := getSourceFromOrigin(origin) var labelDetails *lsproto.CompletionItemLabelDetails - typeChecker := program.GetTypeChecker() + typeChecker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() insertQuestionDot := originIsNullableMember(origin) useBraces := originIsSymbolMember(origin) || needsConvertPropertyAccess if originIsThisType(origin) { diff --git a/internal/ls/completions_test.go b/internal/ls/completions_test.go index 5b78a4a700..3b9b1edc0d 100644 --- a/internal/ls/completions_test.go +++ b/internal/ls/completions_test.go @@ -1,6 +1,7 @@ package ls_test import ( + "context" "slices" "testing" @@ -1551,7 +1552,9 @@ func runTest(t *testing.T, files map[string]string, expected map[string]*testCas parsedFiles[fileName] = content } } - languageService := createLanguageService(mainFileName, parsedFiles) + ctx := projecttestutil.WithRequestID(t.Context()) + languageService, done := createLanguageService(ctx, mainFileName, parsedFiles) + defer done() context := &lsproto.CompletionContext{ TriggerKind: lsproto.CompletionTriggerKindInvoked, } @@ -1575,12 +1578,14 @@ func runTest(t *testing.T, files map[string]string, expected map[string]*testCas if !ok { t.Fatalf("No marker found for '%s'", markerName) } - completionList := languageService.ProvideCompletion( - mainFileName, - marker.Position, + completionList, err := languageService.ProvideCompletion( + ctx, + ls.FileNameToDocumentURI(mainFileName), + marker.LSPosition, context, capabilities, preferences) + assert.NilError(t, err) if expectedResult.isIncludes { assertIncludesItem(t, completionList, expectedResult.list) } else { @@ -1610,11 +1615,11 @@ func assertIncludesItem(t *testing.T, actual *lsproto.CompletionList, expected * return false } -func createLanguageService(fileName string, files map[string]string) *ls.LanguageService { +func createLanguageService(ctx context.Context, fileName string, files map[string]string) (*ls.LanguageService, func()) { projectService, _ := projecttestutil.Setup(files) projectService.OpenFile(fileName, files[fileName], core.ScriptKindTS, "") project := projectService.Projects()[0] - return project.LanguageService() + return project.GetLanguageServiceForRequest(ctx) } func ptrTo[T any](v T) *T { diff --git a/internal/ls/converters.go b/internal/ls/converters.go index 204832b263..690870df03 100644 --- a/internal/ls/converters.go +++ b/internal/ls/converters.go @@ -8,137 +8,60 @@ import ( "unicode/utf16" "unicode/utf8" - "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/core" - "github.com/microsoft/typescript-go/internal/diagnostics" "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -type ScriptInfo interface { - Text() string - LineMap() *LineMap -} - type Converters struct { - getScriptInfo func(fileName string) ScriptInfo + getLineMap func(fileName string) *LineMap positionEncoding lsproto.PositionEncodingKind } -func NewConverters(positionEncoding lsproto.PositionEncodingKind, getScriptInfo func(fileName string) ScriptInfo) *Converters { +type Script interface { + FileName() string + Text() string +} + +func NewConverters(positionEncoding lsproto.PositionEncodingKind, getLineMap func(fileName string) *LineMap) *Converters { return &Converters{ - getScriptInfo: getScriptInfo, + getLineMap: getLineMap, positionEncoding: positionEncoding, } } -func (c *Converters) ToLSPRange(fileName string, textRange core.TextRange) (lsproto.Range, error) { - scriptInfo := c.getScriptInfo(fileName) - if scriptInfo == nil { - return lsproto.Range{}, fmt.Errorf("no script info found for %s", fileName) - } - +func (c *Converters) ToLSPRange(script Script, textRange core.TextRange) lsproto.Range { return lsproto.Range{ - Start: c.PositionToLineAndCharacter(scriptInfo, core.TextPos(textRange.Pos())), - End: c.PositionToLineAndCharacter(scriptInfo, core.TextPos(textRange.End())), - }, nil + Start: c.PositionToLineAndCharacter(script, core.TextPos(textRange.Pos())), + End: c.PositionToLineAndCharacter(script, core.TextPos(textRange.End())), + } } -func (c *Converters) FromLSPRange(textRange lsproto.Range, fileName string) (core.TextRange, error) { - scriptInfo := c.getScriptInfo(fileName) - if scriptInfo == nil { - return core.TextRange{}, fmt.Errorf("no script info found for %s", fileName) - } +func (c *Converters) FromLSPRange(script Script, textRange lsproto.Range) core.TextRange { return core.NewTextRange( - int(c.LineAndCharacterToPosition(scriptInfo, textRange.Start)), - int(c.LineAndCharacterToPosition(scriptInfo, textRange.End)), - ), nil + int(c.LineAndCharacterToPosition(script, textRange.Start)), + int(c.LineAndCharacterToPosition(script, textRange.End)), + ) } -func (c *Converters) FromLSPTextChange(change *lsproto.TextDocumentContentChangePartial, fileName string) (TextChange, error) { - textRange, err := c.FromLSPRange(change.Range, fileName) - if err != nil { - return TextChange{}, fmt.Errorf("error converting range: %w", err) - } +func (c *Converters) FromLSPTextChange(script Script, change *lsproto.TextDocumentContentChangePartial) TextChange { return TextChange{ - TextRange: textRange, + TextRange: c.FromLSPRange(script, change.Range), NewText: change.Text, - }, nil -} - -func (c *Converters) ToLSPLocation(location Location) (lsproto.Location, error) { - rng, err := c.ToLSPRange(location.FileName, location.Range) - if err != nil { - return lsproto.Location{}, err } - return lsproto.Location{ - Uri: FileNameToDocumentURI(location.FileName), - Range: rng, - }, nil } -func (c *Converters) FromLSPLocation(location lsproto.Location) (Location, error) { - fileName := DocumentURIToFileName(location.Uri) - rng, err := c.FromLSPRange(location.Range, fileName) - if err != nil { - return Location{}, err - } - return Location{ - FileName: fileName, - Range: rng, - }, nil -} - -func (c *Converters) ToLSPDiagnostic(diagnostic *ast.Diagnostic) (*lsproto.Diagnostic, error) { - textRange, err := c.ToLSPRange(diagnostic.File().FileName(), diagnostic.Loc()) - if err != nil { - return nil, fmt.Errorf("error converting diagnostic range: %w", err) - } - - var severity lsproto.DiagnosticSeverity - switch diagnostic.Category() { - case diagnostics.CategorySuggestion: - severity = lsproto.DiagnosticSeverityHint - case diagnostics.CategoryMessage: - severity = lsproto.DiagnosticSeverityInformation - case diagnostics.CategoryWarning: - severity = lsproto.DiagnosticSeverityWarning - default: - severity = lsproto.DiagnosticSeverityError - } - - relatedInformation := make([]*lsproto.DiagnosticRelatedInformation, 0, len(diagnostic.RelatedInformation())) - for _, related := range diagnostic.RelatedInformation() { - relatedRange, err := c.ToLSPRange(related.File().FileName(), related.Loc()) - if err != nil { - return nil, fmt.Errorf("error converting related info range: %w", err) - } - relatedInformation = append(relatedInformation, &lsproto.DiagnosticRelatedInformation{ - Location: lsproto.Location{ - Uri: FileNameToDocumentURI(related.File().FileName()), - Range: relatedRange, - }, - Message: related.Message(), - }) +func (c *Converters) ToLSPLocation(script Script, rng core.TextRange) lsproto.Location { + return lsproto.Location{ + Uri: FileNameToDocumentURI(script.FileName()), + Range: c.ToLSPRange(script, rng), } - - return &lsproto.Diagnostic{ - Range: textRange, - Code: &lsproto.IntegerOrString{ - Integer: ptrTo(diagnostic.Code()), - }, - Severity: &severity, - Message: diagnostic.Message(), - Source: ptrTo("ts"), - RelatedInformation: &relatedInformation, - }, nil } -func (c *Converters) LineAndCharacterToPositionForFile(lineAndCharacter lsproto.Position, fileName string) (int, error) { - scriptInfo := c.getScriptInfo(fileName) - if scriptInfo == nil { - return 0, fmt.Errorf("no script info found for %s", fileName) +func (c *Converters) FromLSPLocation(script Script, rng lsproto.Range) Location { + return Location{ + FileName: script.FileName(), + Range: c.FromLSPRange(script, rng), } - return int(c.LineAndCharacterToPosition(scriptInfo, lineAndCharacter)), nil } func LanguageKindToScriptKind(languageID lsproto.LanguageKind) core.ScriptKind { @@ -202,10 +125,10 @@ func FileNameToDocumentURI(fileName string) lsproto.DocumentUri { return lsproto.DocumentUri("file://" + fileName) } -func (c *Converters) LineAndCharacterToPosition(scriptInfo ScriptInfo, lineAndCharacter lsproto.Position) core.TextPos { +func (c *Converters) LineAndCharacterToPosition(script Script, lineAndCharacter lsproto.Position) core.TextPos { // UTF-8/16 0-indexed line and character to UTF-8 offset - lineMap := scriptInfo.LineMap() + lineMap := c.getLineMap(script.FileName()) line := core.TextPos(lineAndCharacter.Line) char := core.TextPos(lineAndCharacter.Character) @@ -222,7 +145,7 @@ func (c *Converters) LineAndCharacterToPosition(scriptInfo ScriptInfo, lineAndCh var utf8Char core.TextPos var utf16Char core.TextPos - for i, r := range scriptInfo.Text()[start:] { + for i, r := range script.Text()[start:] { u16Len := core.TextPos(utf16.RuneLen(r)) if utf16Char+u16Len > char { break @@ -234,10 +157,10 @@ func (c *Converters) LineAndCharacterToPosition(scriptInfo ScriptInfo, lineAndCh return start + utf8Char } -func (c *Converters) PositionToLineAndCharacter(scriptInfo ScriptInfo, position core.TextPos) lsproto.Position { +func (c *Converters) PositionToLineAndCharacter(script Script, position core.TextPos) lsproto.Position { // UTF-8 offset to UTF-8/16 0-indexed line and character - lineMap := scriptInfo.LineMap() + lineMap := c.getLineMap(script.FileName()) line, isLineStart := slices.BinarySearch(lineMap.LineStarts, position) if !isLineStart { @@ -254,7 +177,7 @@ func (c *Converters) PositionToLineAndCharacter(scriptInfo ScriptInfo, position character = position - start } else { // We need to rescan the text as UTF-16 to find the character offset. - for _, r := range scriptInfo.Text()[start:position] { + for _, r := range script.Text()[start:position] { character += core.TextPos(utf16.RuneLen(r)) } } diff --git a/internal/ls/definition.go b/internal/ls/definition.go index f4ed52124f..6ccc155c2e 100644 --- a/internal/ls/definition.go +++ b/internal/ls/definition.go @@ -1,20 +1,25 @@ package ls import ( + "context" + "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/astnav" "github.com/microsoft/typescript-go/internal/core" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" "github.com/microsoft/typescript-go/internal/scanner" ) -func (l *LanguageService) ProvideDefinitions(fileName string, position int) []Location { - program, file := l.getProgramAndFile(fileName) - node := astnav.GetTouchingPropertyName(file, position) +func (l *LanguageService) ProvideDefinition(ctx context.Context, documentURI lsproto.DocumentUri, position lsproto.Position) (*lsproto.Definition, error) { + program, file := l.getProgramAndFile(documentURI) + node := astnav.GetTouchingPropertyName(file, int(l.converters.LineAndCharacterToPosition(file, position))) if node.Kind == ast.KindSourceFile { - return nil + return nil, nil } - checker := program.GetTypeChecker() + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() + if symbol := checker.GetSymbolAtLocation(node); symbol != nil { if symbol.Flags&ast.SymbolFlagsAlias != 0 { if resolved, ok := checker.ResolveAlias(symbol); ok { @@ -22,18 +27,17 @@ func (l *LanguageService) ProvideDefinitions(fileName string, position int) []Lo } } - locations := make([]Location, 0, len(symbol.Declarations)) + locations := make([]lsproto.Location, 0, len(symbol.Declarations)) for _, decl := range symbol.Declarations { file := ast.GetSourceFileOfNode(decl) loc := decl.Loc pos := scanner.GetTokenPosOfNode(decl, file, false /*includeJSDoc*/) - - locations = append(locations, Location{ - FileName: file.FileName(), - Range: core.NewTextRange(pos, loc.End()), + locations = append(locations, lsproto.Location{ + Uri: FileNameToDocumentURI(file.FileName()), + Range: l.converters.ToLSPRange(file, core.NewTextRange(pos, loc.End())), }) } - return locations + return &lsproto.Definition{Locations: &locations}, nil } - return nil + return nil, nil } diff --git a/internal/ls/diagnostics.go b/internal/ls/diagnostics.go index 1994ec740f..5bf4a1d1fd 100644 --- a/internal/ls/diagnostics.go +++ b/internal/ls/diagnostics.go @@ -2,14 +2,72 @@ package ls import ( "context" - "slices" "github.com/microsoft/typescript-go/internal/ast" + "github.com/microsoft/typescript-go/internal/diagnostics" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -func (l *LanguageService) GetDocumentDiagnostics(fileName string) []*ast.Diagnostic { - program, file := l.getProgramAndFile(fileName) - syntaxDiagnostics := program.GetSyntacticDiagnostics(context.Background(), file) - semanticDiagnostics := program.GetSemanticDiagnostics(context.Background(), file) - return slices.Concat(syntaxDiagnostics, semanticDiagnostics) +func (l *LanguageService) GetDocumentDiagnostics(ctx context.Context, documentURI lsproto.DocumentUri) (*lsproto.DocumentDiagnosticReport, error) { + program, file := l.getProgramAndFile(documentURI) + syntaxDiagnostics := program.GetSyntacticDiagnostics(ctx, file) + var lspDiagnostics []*lsproto.Diagnostic + if len(syntaxDiagnostics) != 0 { + lspDiagnostics = make([]*lsproto.Diagnostic, len(syntaxDiagnostics)) + for i, diag := range syntaxDiagnostics { + lspDiagnostics[i] = toLSPDiagnostic(diag, l.converters) + } + } else { + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() + semanticDiagnostics := checker.GetDiagnostics(ctx, file) + lspDiagnostics = make([]*lsproto.Diagnostic, len(semanticDiagnostics)) + for i, diag := range semanticDiagnostics { + lspDiagnostics[i] = toLSPDiagnostic(diag, l.converters) + } + } + return &lsproto.DocumentDiagnosticReport{ + RelatedFullDocumentDiagnosticReport: &lsproto.RelatedFullDocumentDiagnosticReport{ + FullDocumentDiagnosticReport: lsproto.FullDocumentDiagnosticReport{ + Kind: lsproto.StringLiteralFull{}, + Items: lspDiagnostics, + }, + }, + }, nil +} + +func toLSPDiagnostic(diagnostic *ast.Diagnostic, converters *Converters) *lsproto.Diagnostic { + var severity lsproto.DiagnosticSeverity + switch diagnostic.Category() { + case diagnostics.CategorySuggestion: + severity = lsproto.DiagnosticSeverityHint + case diagnostics.CategoryMessage: + severity = lsproto.DiagnosticSeverityInformation + case diagnostics.CategoryWarning: + severity = lsproto.DiagnosticSeverityWarning + default: + severity = lsproto.DiagnosticSeverityError + } + + relatedInformation := make([]*lsproto.DiagnosticRelatedInformation, 0, len(diagnostic.RelatedInformation())) + for _, related := range diagnostic.RelatedInformation() { + relatedInformation = append(relatedInformation, &lsproto.DiagnosticRelatedInformation{ + Location: lsproto.Location{ + Uri: FileNameToDocumentURI(related.File().FileName()), + Range: converters.ToLSPRange(related.File(), related.Loc()), + }, + Message: related.Message(), + }) + } + + return &lsproto.Diagnostic{ + Range: converters.ToLSPRange(diagnostic.File(), diagnostic.Loc()), + Code: &lsproto.IntegerOrString{ + Integer: ptrTo(diagnostic.Code()), + }, + Severity: &severity, + Message: diagnostic.Message(), + Source: ptrTo("ts"), + RelatedInformation: &relatedInformation, + } } diff --git a/internal/ls/host.go b/internal/ls/host.go index d54cd1b3ad..493596fc45 100644 --- a/internal/ls/host.go +++ b/internal/ls/host.go @@ -1,30 +1,12 @@ package ls import ( - "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/compiler" - "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/lsp/lsproto" - "github.com/microsoft/typescript-go/internal/tspath" - "github.com/microsoft/typescript-go/internal/vfs" ) type Host interface { - FS() vfs.FS - DefaultLibraryPath() string - GetCurrentDirectory() string - NewLine() string - Trace(msg string) - GetProjectVersion() int - // GetRootFileNames was called GetScriptFileNames in the original code. - GetRootFileNames() []string - // GetCompilerOptions was called GetCompilationSettings in the original code. - GetCompilerOptions() *core.CompilerOptions - GetSourceFile(fileName string, path tspath.Path, languageVersion core.ScriptTarget) *ast.SourceFile - // This responsibility was moved from the language service to the project, - // because they were bidirectionally interdependent. GetProgram() *compiler.Program - GetDefaultLibraryPath() string GetPositionEncoding() lsproto.PositionEncodingKind - GetScriptInfo(fileName string) ScriptInfo + GetLineMap(fileName string) *LineMap } diff --git a/internal/ls/hover.go b/internal/ls/hover.go index fdf5273f40..395148ba47 100644 --- a/internal/ls/hover.go +++ b/internal/ls/hover.go @@ -1,16 +1,56 @@ package ls import ( + "context" + "strings" + "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/astnav" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -func (l *LanguageService) ProvideHover(fileName string, position int) string { - program, file := l.getProgramAndFile(fileName) - node := astnav.GetTouchingPropertyName(file, position) +func (l *LanguageService) ProvideHover(ctx context.Context, documentURI lsproto.DocumentUri, position lsproto.Position) (*lsproto.Hover, error) { + program, file := l.getProgramAndFile(documentURI) + node := astnav.GetTouchingPropertyName(file, int(l.converters.LineAndCharacterToPosition(file, position))) if node.Kind == ast.KindSourceFile { // Avoid giving quickInfo for the sourceFile as a whole. + return nil, nil + } + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() + result := checker.GetQuickInfoAtLocation(node) + if result != "" { + return &lsproto.Hover{ + Contents: lsproto.MarkupContentOrMarkedStringOrMarkedStrings{ + MarkupContent: &lsproto.MarkupContent{ + Kind: lsproto.MarkupKindMarkdown, + Value: codeFence("typescript", result), + }, + }, + }, nil + } + return nil, nil +} + +func codeFence(lang string, code string) string { + if code == "" { return "" } - return program.GetTypeChecker().GetQuickInfoAtLocation(node) + ticks := 3 + for strings.Contains(code, strings.Repeat("`", ticks)) { + ticks++ + } + var result strings.Builder + result.Grow(len(code) + len(lang) + 2*ticks + 2) + for range ticks { + result.WriteByte('`') + } + result.WriteString(lang) + result.WriteByte('\n') + result.WriteString(code) + result.WriteByte('\n') + for range ticks { + result.WriteByte('`') + } + return result.String() } diff --git a/internal/ls/languageservice.go b/internal/ls/languageservice.go index 2fc19e4549..e441a2fee0 100644 --- a/internal/ls/languageservice.go +++ b/internal/ls/languageservice.go @@ -1,57 +1,26 @@ package ls import ( + "context" + "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/compiler" - "github.com/microsoft/typescript-go/internal/core" - "github.com/microsoft/typescript-go/internal/tspath" - "github.com/microsoft/typescript-go/internal/vfs" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -var _ compiler.CompilerHost = (*LanguageService)(nil) - type LanguageService struct { - converters *Converters + ctx context.Context host Host + converters *Converters } -func NewLanguageService(host Host) *LanguageService { +func NewLanguageService(ctx context.Context, host Host) *LanguageService { return &LanguageService{ host: host, - converters: NewConverters(host.GetPositionEncoding(), host.GetScriptInfo), + converters: NewConverters(host.GetPositionEncoding(), host.GetLineMap), } } -// FS implements compiler.CompilerHost. -func (l *LanguageService) FS() vfs.FS { - return l.host.FS() -} - -// DefaultLibraryPath implements compiler.CompilerHost. -func (l *LanguageService) DefaultLibraryPath() string { - return l.host.DefaultLibraryPath() -} - -// GetCurrentDirectory implements compiler.CompilerHost. -func (l *LanguageService) GetCurrentDirectory() string { - return l.host.GetCurrentDirectory() -} - -// NewLine implements compiler.CompilerHost. -func (l *LanguageService) NewLine() string { - return l.host.NewLine() -} - -// Trace implements compiler.CompilerHost. -func (l *LanguageService) Trace(msg string) { - l.host.Trace(msg) -} - -// GetSourceFile implements compiler.CompilerHost. -func (l *LanguageService) GetSourceFile(fileName string, path tspath.Path, languageVersion core.ScriptTarget) *ast.SourceFile { - return l.host.GetSourceFile(fileName, path, languageVersion) -} - // GetProgram updates the program if the project version has changed. func (l *LanguageService) GetProgram() *compiler.Program { return l.host.GetProgram() @@ -63,7 +32,8 @@ func (l *LanguageService) tryGetProgramAndFile(fileName string) (*compiler.Progr return program, file } -func (l *LanguageService) getProgramAndFile(fileName string) (*compiler.Program, *ast.SourceFile) { +func (l *LanguageService) getProgramAndFile(documentURI lsproto.DocumentUri) (*compiler.Program, *ast.SourceFile) { + fileName := DocumentURIToFileName(documentURI) program, file := l.tryGetProgramAndFile(fileName) if file == nil { panic("file not found: " + fileName) diff --git a/internal/ls/utilities.go b/internal/ls/utilities.go index 2f159d5639..02286e0b8c 100644 --- a/internal/ls/utilities.go +++ b/internal/ls/utilities.go @@ -268,10 +268,7 @@ func (l *LanguageService) createLspRangeFromNode(node *ast.Node, file *ast.Sourc } func (l *LanguageService) createLspRangeFromBounds(start, end int, file *ast.SourceFile) *lsproto.Range { - lspRange, err := l.converters.ToLSPRange(file.FileName(), core.NewTextRange(start, end)) - if err != nil { - panic(err) - } + lspRange := l.converters.ToLSPRange(file, core.NewTextRange(start, end)) return &lspRange } diff --git a/internal/lsp/lsproto/jsonrpc.go b/internal/lsp/lsproto/jsonrpc.go index e909e9442e..36e756fc7a 100644 --- a/internal/lsp/lsproto/jsonrpc.go +++ b/internal/lsp/lsproto/jsonrpc.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "strconv" ) type JSONRPCVersion struct{} @@ -28,10 +29,24 @@ type ID struct { int int32 } +func NewID(rawValue IntegerOrString) *ID { + if rawValue.String != nil { + return &ID{str: *rawValue.String} + } + return &ID{int: *rawValue.Integer} +} + func NewIDString(str string) *ID { return &ID{str: str} } +func (id *ID) String() string { + if id.str != "" { + return id.str + } + return strconv.Itoa(int(id.int)) +} + func (id *ID) MarshalJSON() ([]byte, error) { if id.str != "" { return json.Marshal(id.str) @@ -61,13 +76,92 @@ func (id *ID) MustInt() int32 { return id.int } -// TODO(jakebailey): NotificationMessage? Use RequestMessage without ID? +type MessageKind int + +const ( + MessageKindNotification MessageKind = iota + MessageKindRequest + MessageKindResponse +) + +type Message struct { + Kind MessageKind + msg any +} + +func (m *Message) AsRequest() *RequestMessage { + return m.msg.(*RequestMessage) +} + +func (m *Message) AsResponse() *ResponseMessage { + return m.msg.(*ResponseMessage) +} + +func (m *Message) UnmarshalJSON(data []byte) error { + var raw struct { + JSONRPC JSONRPCVersion `json:"jsonrpc"` + Method Method `json:"method"` + ID *ID `json:"id,omitempty"` + Params json.RawMessage `json:"params"` + Result any `json:"result,omitempty"` + Error *ResponseError `json:"error,omitempty"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("%w: %w", ErrInvalidRequest, err) + } + if raw.ID != nil && raw.Method == "" { + m.Kind = MessageKindResponse + m.msg = &ResponseMessage{ + JSONRPC: raw.JSONRPC, + ID: raw.ID, + Result: raw.Result, + Error: raw.Error, + } + return nil + } + + var params any + var err error + if len(raw.Params) > 0 { + params, err = unmarshalParams(raw.Method, raw.Params) + if err != nil { + return fmt.Errorf("%w: %w", ErrInvalidRequest, err) + } + } + + if raw.ID == nil { + m.Kind = MessageKindNotification + } else { + m.Kind = MessageKindRequest + } + + m.msg = &RequestMessage{ + JSONRPC: raw.JSONRPC, + ID: raw.ID, + Method: raw.Method, + Params: params, + } + + return nil +} + +func (m *Message) MarshalJSON() ([]byte, error) { + return json.Marshal(m.msg) +} + +func NewNotificationMessage(method Method, params any) *RequestMessage { + return &RequestMessage{ + JSONRPC: JSONRPCVersion{}, + Method: method, + Params: params, + } +} type RequestMessage struct { JSONRPC JSONRPCVersion `json:"jsonrpc"` ID *ID `json:"id,omitempty"` Method Method `json:"method"` - Params any `json:"params"` + Params any `json:"params,omitempty"` } func NewRequestMessage(method Method, id *ID, params any) *RequestMessage { @@ -78,6 +172,13 @@ func NewRequestMessage(method Method, id *ID, params any) *RequestMessage { } } +func (r *RequestMessage) Message() *Message { + return &Message{ + Kind: MessageKindRequest, + msg: r, + } +} + func (r *RequestMessage) UnmarshalJSON(data []byte) error { var raw struct { JSONRPC JSONRPCVersion `json:"jsonrpc"` @@ -108,8 +209,26 @@ type ResponseMessage struct { Error *ResponseError `json:"error,omitempty"` } +func (r *ResponseMessage) Message() *Message { + return &Message{ + Kind: MessageKindResponse, + msg: r, + } +} + type ResponseError struct { Code int32 `json:"code"` Message string `json:"message"` Data any `json:"data,omitempty"` } + +func (r *ResponseError) String() string { + if r == nil { + return "" + } + data, err := json.Marshal(r.Data) + if err != nil { + return fmt.Sprintf("[%d]: %s\n%v", r.Code, r.Message, data) + } + return fmt.Sprintf("[%d]: %s", r.Code, r.Message) +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 2931f40178..2c9ff587a6 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -1,20 +1,24 @@ package lsp import ( + "context" "encoding/json" "errors" "fmt" "io" + "os" + "os/signal" "runtime/debug" "slices" - "strings" - "time" + "sync" + "syscall" "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/ls" "github.com/microsoft/typescript-go/internal/lsp/lsproto" "github.com/microsoft/typescript-go/internal/project" "github.com/microsoft/typescript-go/internal/vfs" + "golang.org/x/sync/errgroup" ) type ServerOptions struct { @@ -33,13 +37,17 @@ func NewServer(opts *ServerOptions) *Server { panic("Cwd is required") } return &Server{ - r: lsproto.NewBaseReader(opts.In), - w: lsproto.NewBaseWriter(opts.Out), - stderr: opts.Err, - cwd: opts.Cwd, - newLine: opts.NewLine, - fs: opts.FS, - defaultLibraryPath: opts.DefaultLibraryPath, + r: lsproto.NewBaseReader(opts.In), + w: lsproto.NewBaseWriter(opts.Out), + stderr: opts.Err, + requestQueue: make(chan *lsproto.RequestMessage, 100), + outgoingQueue: make(chan *lsproto.Message, 100), + pendingClientRequests: make(map[lsproto.ID]pendingClientRequest), + pendingServerRequests: make(map[lsproto.ID]chan *lsproto.ResponseMessage), + cwd: opts.Cwd, + newLine: opts.NewLine, + fs: opts.FS, + defaultLibraryPath: opts.DefaultLibraryPath, } } @@ -48,15 +56,24 @@ var ( _ project.Client = (*Server)(nil) ) +type pendingClientRequest struct { + req *lsproto.RequestMessage + cancel context.CancelFunc +} + type Server struct { r *lsproto.BaseReader w *lsproto.BaseWriter stderr io.Writer - clientSeq int32 - requestMethod string - requestTime time.Time + clientSeq int32 + requestQueue chan *lsproto.RequestMessage + outgoingQueue chan *lsproto.Message + pendingClientRequests map[lsproto.ID]pendingClientRequest + pendingClientRequestsMu sync.Mutex + pendingServerRequests map[lsproto.ID]chan *lsproto.ResponseMessage + pendingServerRequestsMu sync.Mutex cwd string newLine core.NewLineKind @@ -71,7 +88,6 @@ type Server struct { watchers core.Set[project.WatcherHandle] logger *project.Logger projectService *project.Service - converters *ls.Converters } // FS implements project.ServiceHost. @@ -108,9 +124,9 @@ func (s *Server) Client() project.Client { } // WatchFiles implements project.Client. -func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { +func (s *Server) WatchFiles(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { watcherId := fmt.Sprintf("watcher-%d", s.watcherID) - if err := s.sendRequest(lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ + _, err := s.sendRequest(ctx, lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ Registrations: []*lsproto.Registration{ { Id: watcherId, @@ -120,7 +136,8 @@ func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.Watc })), }, }, - }); err != nil { + }) + if err != nil { return "", fmt.Errorf("failed to register file watcher: %w", err) } @@ -131,18 +148,20 @@ func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.Watc } // UnwatchFiles implements project.Client. -func (s *Server) UnwatchFiles(handle project.WatcherHandle) error { +func (s *Server) UnwatchFiles(ctx context.Context, handle project.WatcherHandle) error { if s.watchers.Has(handle) { - if err := s.sendRequest(lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{ + _, err := s.sendRequest(ctx, lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{ Unregisterations: []*lsproto.Unregistration{ { Id: string(handle), Method: string(lsproto.MethodWorkspaceDidChangeWatchedFiles), }, }, - }); err != nil { + }) + if err != nil { return fmt.Errorf("failed to unregister file watcher: %w", err) } + s.watchers.Delete(handle) return nil } @@ -151,9 +170,9 @@ func (s *Server) UnwatchFiles(handle project.WatcherHandle) error { } // RefreshDiagnostics implements project.Client. -func (s *Server) RefreshDiagnostics() error { +func (s *Server) RefreshDiagnostics(ctx context.Context) error { if ptrIsTrue(s.initializeParams.Capabilities.Workspace.Diagnostics.RefreshSupport) { - if err := s.sendRequest(lsproto.MethodWorkspaceDiagnosticRefresh, nil); err != nil { + if _, err := s.sendRequest(ctx, lsproto.MethodWorkspaceDiagnosticRefresh, nil); err != nil { return fmt.Errorf("failed to refresh diagnostics: %w", err) } } @@ -161,95 +180,189 @@ func (s *Server) RefreshDiagnostics() error { } func (s *Server) Run() error { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { return s.dispatchLoop(ctx) }) + g.Go(func() error { return s.writeLoop(ctx) }) + g.Go(func() error { return s.readLoop(ctx) }) + return g.Wait() +} + +func (s *Server) readLoop(ctx context.Context) error { for { - req, err := s.read() + msg, err := s.read() if err != nil { + if errors.Is(err, io.EOF) { + return nil + } if errors.Is(err, lsproto.ErrInvalidRequest) { - if err := s.sendError(nil, err); err != nil { - return err - } + s.sendError(nil, err) continue } return err } - // TODO: handle response messages - if req == nil { - continue - } - - if s.initializeParams == nil { + if s.initializeParams == nil && msg.Kind == lsproto.MessageKindRequest { + req := msg.AsRequest() if req.Method == lsproto.MethodInitialize { - if err := s.handleInitialize(req); err != nil { - return err - } + s.handleInitialize(req) } else { - if err := s.sendError(req.ID, lsproto.ErrServerNotInitialized); err != nil { - return err - } + s.sendError(req.ID, lsproto.ErrServerNotInitialized) } continue } - if err := s.handleMessage(req); err != nil { - return err + if msg.Kind == lsproto.MessageKindResponse { + resp := msg.AsResponse() + s.pendingServerRequestsMu.Lock() + if respChan, ok := s.pendingServerRequests[*resp.ID]; ok { + respChan <- resp + close(respChan) + delete(s.pendingServerRequests, *resp.ID) + } + s.pendingServerRequestsMu.Unlock() + } else { + req := msg.AsRequest() + if req.Method == lsproto.MethodCancelRequest { + s.cancelRequest(req.Params.(*lsproto.CancelParams).Id) + } else { + s.requestQueue <- req + } } } } -func (s *Server) read() (*lsproto.RequestMessage, error) { +func (s *Server) cancelRequest(rawID lsproto.IntegerOrString) { + id := lsproto.NewID(rawID) + s.pendingClientRequestsMu.Lock() + defer s.pendingClientRequestsMu.Unlock() + if pendingReq, ok := s.pendingClientRequests[*id]; ok { + pendingReq.cancel() + delete(s.pendingClientRequests, *id) + } +} + +func (s *Server) read() (*lsproto.Message, error) { data, err := s.r.Read() if err != nil { return nil, err } - req := &lsproto.RequestMessage{} + req := &lsproto.Message{} if err := json.Unmarshal(data, req); err != nil { - res := &lsproto.ResponseMessage{} - if err = json.Unmarshal(data, res); err == nil { - // !!! TODO: handle response - return nil, nil - } return nil, fmt.Errorf("%w: %w", lsproto.ErrInvalidRequest, err) } return req, nil } -func (s *Server) sendRequest(method lsproto.Method, params any) error { +func (s *Server) dispatchLoop(ctx context.Context) error { + ctx, lspExit := context.WithCancel(ctx) + defer lspExit() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-s.requestQueue: + requestCtx := ctx + if req.ID != nil { + var cancel context.CancelFunc + requestCtx, cancel = context.WithCancel(core.WithRequestID(requestCtx, req.ID.String())) + s.pendingClientRequestsMu.Lock() + s.pendingClientRequests[*req.ID] = pendingClientRequest{ + req: req, + cancel: cancel, + } + s.pendingClientRequestsMu.Unlock() + } + + handle := func() { + if err := s.handleRequestOrNotification(requestCtx, req); err != nil { + if errors.Is(err, io.EOF) { + lspExit() + } else { + s.sendError(req.ID, err) + } + } + + if req.ID != nil { + s.pendingClientRequestsMu.Lock() + delete(s.pendingClientRequests, *req.ID) + s.pendingClientRequestsMu.Unlock() + } + } + + if isBlockingMethod(req.Method) { + handle() + } else { + go handle() + } + } + } +} + +func (s *Server) writeLoop(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case msg := <-s.outgoingQueue: + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + if err := s.w.Write(data); err != nil { + return fmt.Errorf("failed to write message: %w", err) + } + } + } +} + +func (s *Server) sendRequest(ctx context.Context, method lsproto.Method, params any) (any, error) { s.clientSeq++ id := lsproto.NewIDString(fmt.Sprintf("ts%d", s.clientSeq)) req := lsproto.NewRequestMessage(method, id, params) - data, err := json.Marshal(req) - if err != nil { - return err - } - return s.w.Write(data) -} -func (s *Server) sendNotification(method lsproto.Method, params any) error { - req := lsproto.NewRequestMessage(method, nil /*id*/, params) - data, err := json.Marshal(req) - if err != nil { - return err + responseChan := make(chan *lsproto.ResponseMessage, 1) + s.pendingServerRequestsMu.Lock() + s.pendingServerRequests[*id] = responseChan + s.pendingServerRequestsMu.Unlock() + + s.outgoingQueue <- req.Message() + + select { + case <-ctx.Done(): + s.pendingServerRequestsMu.Lock() + defer s.pendingServerRequestsMu.Unlock() + if respChan, ok := s.pendingServerRequests[*id]; ok { + close(respChan) + delete(s.pendingServerRequests, *id) + } + return nil, ctx.Err() + case resp := <-responseChan: + if resp.Error != nil { + return nil, fmt.Errorf("request failed: %s", resp.Error.String()) + } + return resp.Result, nil } - return s.w.Write(data) } -func (s *Server) sendResult(id *lsproto.ID, result any) error { - return s.sendResponse(&lsproto.ResponseMessage{ +func (s *Server) sendResult(id *lsproto.ID, result any) { + s.sendResponse(&lsproto.ResponseMessage{ ID: id, Result: result, }) } -func (s *Server) sendError(id *lsproto.ID, err error) error { +func (s *Server) sendError(id *lsproto.ID, err error) { code := lsproto.ErrInternalError.Code if errCode := (*lsproto.ErrorCode)(nil); errors.As(err, &errCode) { code = errCode.Code } // TODO(jakebailey): error data - return s.sendResponse(&lsproto.ResponseMessage{ + s.sendResponse(&lsproto.ResponseMessage{ ID: id, Error: &lsproto.ResponseError{ Code: code, @@ -258,63 +371,55 @@ func (s *Server) sendError(id *lsproto.ID, err error) error { }) } -func (s *Server) sendResponse(resp *lsproto.ResponseMessage) error { - if !s.requestTime.IsZero() { - s.logger.PerfTrace(fmt.Sprintf("%s: %s", s.requestMethod, time.Since(s.requestTime))) - } - data, err := json.Marshal(resp) - if err != nil { - return err - } - return s.w.Write(data) +func (s *Server) sendResponse(resp *lsproto.ResponseMessage) { + s.outgoingQueue <- resp.Message() } -func (s *Server) handleMessage(req *lsproto.RequestMessage) error { - s.requestTime = time.Now() - s.requestMethod = string(req.Method) - +func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params switch params.(type) { case *lsproto.InitializeParams: - return s.sendError(req.ID, lsproto.ErrInvalidRequest) + s.sendError(req.ID, lsproto.ErrInvalidRequest) + return nil case *lsproto.InitializedParams: - return s.handleInitialized(req) + return s.handleInitialized(ctx, req) case *lsproto.DidOpenTextDocumentParams: - return s.handleDidOpen(req) + return s.handleDidOpen(ctx, req) case *lsproto.DidChangeTextDocumentParams: - return s.handleDidChange(req) + return s.handleDidChange(ctx, req) case *lsproto.DidSaveTextDocumentParams: - return s.handleDidSave(req) + return s.handleDidSave(ctx, req) case *lsproto.DidCloseTextDocumentParams: - return s.handleDidClose(req) + return s.handleDidClose(ctx, req) case *lsproto.DidChangeWatchedFilesParams: - return s.handleDidChangeWatchedFiles(req) + return s.handleDidChangeWatchedFiles(ctx, req) case *lsproto.DocumentDiagnosticParams: - return s.handleDocumentDiagnostic(req) + return s.handleDocumentDiagnostic(ctx, req) case *lsproto.HoverParams: - return s.handleHover(req) + return s.handleHover(ctx, req) case *lsproto.DefinitionParams: - return s.handleDefinition(req) + return s.handleDefinition(ctx, req) case *lsproto.CompletionParams: - return s.handleCompletion(req) + return s.handleCompletion(ctx, req) default: switch req.Method { case lsproto.MethodShutdown: s.projectService.Close() - return s.sendResult(req.ID, nil) - case lsproto.MethodExit: + s.sendResult(req.ID, nil) return nil + case lsproto.MethodExit: + return io.EOF default: s.Log("unknown method", req.Method) if req.ID != nil { - return s.sendError(req.ID, lsproto.ErrInvalidRequest) + s.sendError(req.ID, lsproto.ErrInvalidRequest) } return nil } } } -func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { +func (s *Server) handleInitialize(req *lsproto.RequestMessage) { s.initializeParams = req.Params.(*lsproto.InitializeParams) s.positionEncoding = lsproto.PositionEncodingKindUTF16 @@ -324,7 +429,7 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { } } - return s.sendResult(req.ID, &lsproto.InitializeResult{ + s.sendResult(req.ID, &lsproto.InitializeResult{ ServerInfo: &lsproto.ServerInfo{ Name: "typescript-go", Version: ptrTo(core.Version), @@ -361,7 +466,7 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { }) } -func (s *Server) handleInitialized(req *lsproto.RequestMessage) error { +func (s *Server) handleInitialized(ctx context.Context, req *lsproto.RequestMessage) error { if s.initializeParams.Capabilities.Workspace.DidChangeWatchedFiles != nil && *s.initializeParams.Capabilities.Workspace.DidChangeWatchedFiles.DynamicRegistration { s.watchEnabled = true } @@ -373,183 +478,120 @@ func (s *Server) handleInitialized(req *lsproto.RequestMessage) error { PositionEncoding: s.positionEncoding, }) - s.converters = ls.NewConverters(s.positionEncoding, func(fileName string) ls.ScriptInfo { - return s.projectService.GetScriptInfo(fileName) - }) - return nil } -func (s *Server) handleDidOpen(req *lsproto.RequestMessage) error { +func (s *Server) handleDidOpen(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidOpenTextDocumentParams) s.projectService.OpenFile(ls.DocumentURIToFileName(params.TextDocument.Uri), params.TextDocument.Text, ls.LanguageKindToScriptKind(params.TextDocument.LanguageId), "") return nil } -func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { +func (s *Server) handleDidChange(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidChangeTextDocumentParams) - scriptInfo := s.projectService.GetScriptInfo(ls.DocumentURIToFileName(params.TextDocument.Uri)) - if scriptInfo == nil { - return s.sendError(req.ID, lsproto.ErrRequestFailed) - } - - changes := make([]ls.TextChange, len(params.ContentChanges)) - for i, change := range params.ContentChanges { - if partialChange := change.TextDocumentContentChangePartial; partialChange != nil { - if textChange, err := s.converters.FromLSPTextChange(partialChange, scriptInfo.FileName()); err != nil { - return s.sendError(req.ID, err) - } else { - changes[i] = textChange - } - } else if wholeChange := change.TextDocumentContentChangeWholeDocument; wholeChange != nil { - changes[i] = ls.TextChange{ - TextRange: core.NewTextRange(0, len(scriptInfo.Text())), - NewText: wholeChange.Text, - } - } else { - return s.sendError(req.ID, lsproto.ErrInvalidRequest) - } - } - - s.projectService.ChangeFile(ls.DocumentURIToFileName(params.TextDocument.Uri), changes) - return nil + return s.projectService.ChangeFile(params.TextDocument, params.ContentChanges) } -func (s *Server) handleDidSave(req *lsproto.RequestMessage) error { +func (s *Server) handleDidSave(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidSaveTextDocumentParams) s.projectService.MarkFileSaved(ls.DocumentURIToFileName(params.TextDocument.Uri), *params.Text) return nil } -func (s *Server) handleDidClose(req *lsproto.RequestMessage) error { +func (s *Server) handleDidClose(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidCloseTextDocumentParams) s.projectService.CloseFile(ls.DocumentURIToFileName(params.TextDocument.Uri)) return nil } -func (s *Server) handleDidChangeWatchedFiles(req *lsproto.RequestMessage) error { +func (s *Server) handleDidChangeWatchedFiles(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidChangeWatchedFilesParams) - return s.projectService.OnWatchedFilesChanged(params.Changes) + return s.projectService.OnWatchedFilesChanged(ctx, params.Changes) } -func (s *Server) handleDocumentDiagnostic(req *lsproto.RequestMessage) error { +func (s *Server) handleDocumentDiagnostic(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DocumentDiagnosticParams) - file, project := s.getFileAndProject(params.TextDocument.Uri) - diagnostics := project.LanguageService().GetDocumentDiagnostics(file.FileName()) - lspDiagnostics := make([]*lsproto.Diagnostic, len(diagnostics)) - for i, diag := range diagnostics { - if lspDiagnostic, err := s.converters.ToLSPDiagnostic(diag); err != nil { - return s.sendError(req.ID, err) - } else { - lspDiagnostics[i] = lspDiagnostic - } + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + diagnostics, err := languageService.GetDocumentDiagnostics(ctx, params.TextDocument.Uri) + if err != nil { + return err } - return s.sendResult(req.ID, &lsproto.DocumentDiagnosticReport{ - RelatedFullDocumentDiagnosticReport: &lsproto.RelatedFullDocumentDiagnosticReport{ - FullDocumentDiagnosticReport: lsproto.FullDocumentDiagnosticReport{ - Kind: lsproto.StringLiteralFull{}, - Items: lspDiagnostics, - }, - }, - }) + s.sendResult(req.ID, diagnostics) + return nil } -func (s *Server) handleHover(req *lsproto.RequestMessage) error { +func (s *Server) handleHover(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.HoverParams) - file, project := s.getFileAndProject(params.TextDocument.Uri) - pos, err := s.converters.LineAndCharacterToPositionForFile(params.Position, file.FileName()) + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + hover, err := languageService.ProvideHover(ctx, params.TextDocument.Uri, params.Position) if err != nil { - return s.sendError(req.ID, err) + return err } - - hoverText := project.LanguageService().ProvideHover(file.FileName(), pos) - return s.sendResult(req.ID, &lsproto.Hover{ - Contents: lsproto.MarkupContentOrMarkedStringOrMarkedStrings{ - MarkupContent: &lsproto.MarkupContent{ - Kind: lsproto.MarkupKindMarkdown, - Value: codeFence("ts", hoverText), - }, - }, - }) + s.sendResult(req.ID, hover) + return nil } -func (s *Server) handleDefinition(req *lsproto.RequestMessage) error { +func (s *Server) handleDefinition(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DefinitionParams) - file, project := s.getFileAndProject(params.TextDocument.Uri) - pos, err := s.converters.LineAndCharacterToPositionForFile(params.Position, file.FileName()) + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + definition, err := languageService.ProvideDefinition(ctx, params.TextDocument.Uri, params.Position) if err != nil { - return s.sendError(req.ID, err) - } - - locations := project.LanguageService().ProvideDefinitions(file.FileName(), pos) - lspLocations := make([]lsproto.Location, len(locations)) - for i, loc := range locations { - if lspLocation, err := s.converters.ToLSPLocation(loc); err != nil { - return s.sendError(req.ID, err) - } else { - lspLocations[i] = lspLocation - } + return err } - - return s.sendResult(req.ID, &lsproto.Definition{Locations: &lspLocations}) + s.sendResult(req.ID, definition) + return nil } -func (s *Server) handleCompletion(req *lsproto.RequestMessage) (messageErr error) { +func (s *Server) handleCompletion(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.CompletionParams) - file, project := s.getFileAndProject(params.TextDocument.Uri) - pos, err := s.converters.LineAndCharacterToPositionForFile(params.Position, file.FileName()) - if err != nil { - return s.sendError(req.ID, err) - } - + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() // !!! remove this after completions is fully ported/tested defer func() { if r := recover(); r != nil { stack := debug.Stack() s.Log("panic obtaining completions:", r, string(stack)) - messageErr = s.sendResult(req.ID, &lsproto.CompletionList{}) + s.sendResult(req.ID, &lsproto.CompletionList{}) } }() // !!! get user preferences - list := project.LanguageService().ProvideCompletion( - file.FileName(), - pos, + list, err := languageService.ProvideCompletion( + ctx, + params.TextDocument.Uri, + params.Position, params.Context, s.initializeParams.Capabilities.TextDocument.Completion, &ls.UserPreferences{}) - return s.sendResult(req.ID, list) -} - -func (s *Server) getFileAndProject(uri lsproto.DocumentUri) (*project.ScriptInfo, *project.Project) { - fileName := ls.DocumentURIToFileName(uri) - return s.projectService.EnsureDefaultProjectForFile(fileName) + if err != nil { + return err + } + s.sendResult(req.ID, list) + return nil } func (s *Server) Log(msg ...any) { fmt.Fprintln(s.stderr, msg...) } -func codeFence(lang string, code string) string { - if code == "" { - return "" - } - ticks := 3 - for strings.Contains(code, strings.Repeat("`", ticks)) { - ticks++ - } - var result strings.Builder - result.Grow(len(code) + len(lang) + 2*ticks + 2) - for range ticks { - result.WriteByte('`') - } - result.WriteString(lang) - result.WriteByte('\n') - result.WriteString(code) - result.WriteByte('\n') - for range ticks { - result.WriteByte('`') +func isBlockingMethod(method lsproto.Method) bool { + switch method { + case lsproto.MethodInitialize, + lsproto.MethodInitialized, + lsproto.MethodTextDocumentDidOpen, + lsproto.MethodTextDocumentDidChange, + lsproto.MethodTextDocumentDidSave, + lsproto.MethodTextDocumentDidClose, + lsproto.MethodWorkspaceDidChangeWatchedFiles: + return true } - return result.String() + return false } func ptrTo[T any](v T) *T { diff --git a/internal/project/checkerpool.go b/internal/project/checkerpool.go new file mode 100644 index 0000000000..2d1f4edd23 --- /dev/null +++ b/internal/project/checkerpool.go @@ -0,0 +1,233 @@ +package project + +import ( + "context" + "fmt" + "iter" + "sync" + + "github.com/microsoft/typescript-go/internal/ast" + "github.com/microsoft/typescript-go/internal/checker" + "github.com/microsoft/typescript-go/internal/compiler" + "github.com/microsoft/typescript-go/internal/core" +) + +type checkerPool struct { + maxCheckers int + program *compiler.Program + + mu sync.Mutex + cond *sync.Cond + createCheckersOnce sync.Once + checkers []*checker.Checker + inUse map[*checker.Checker]bool + fileAssociations map[*ast.SourceFile]int + requestAssociations map[string]int + log func(msg string) +} + +var _ compiler.CheckerPool = (*checkerPool)(nil) + +func newCheckerPool(maxCheckers int, program *compiler.Program, log func(msg string)) *checkerPool { + pool := &checkerPool{ + program: program, + maxCheckers: maxCheckers, + checkers: make([]*checker.Checker, maxCheckers), + inUse: make(map[*checker.Checker]bool), + requestAssociations: make(map[string]int), + log: log, + } + + pool.cond = sync.NewCond(&pool.mu) + return pool +} + +func (p *checkerPool) GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) { + p.mu.Lock() + defer p.mu.Unlock() + + requestID := core.GetRequestID(ctx) + if requestID != "" { + if checker, release := p.getRequestCheckerLocked(requestID); checker != nil { + return checker, release + } + } + + if p.fileAssociations == nil { + p.fileAssociations = make(map[*ast.SourceFile]int) + } + + if index, ok := p.fileAssociations[file]; ok { + checker := p.checkers[index] + if checker != nil { + if inUse := p.inUse[checker]; !inUse { + p.inUse[checker] = true + if requestID != "" { + p.requestAssociations[requestID] = index + } + return checker, p.createRelease(requestID, index, checker) + } + } + } + + checker, index := p.getCheckerLocked(requestID) + p.fileAssociations[file] = index + return checker, p.createRelease(requestID, index, checker) +} + +func (p *checkerPool) GetChecker(ctx context.Context) (*checker.Checker, func()) { + p.mu.Lock() + defer p.mu.Unlock() + checker, index := p.getCheckerLocked(core.GetRequestID(ctx)) + return checker, p.createRelease(core.GetRequestID(ctx), index, checker) +} + +func (p *checkerPool) Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] { + panic("unimplemented") +} + +func (p *checkerPool) GetAllCheckers(ctx context.Context) ([]*checker.Checker, func()) { + requestID := core.GetRequestID(ctx) + if requestID == "" { + panic("cannot call GetAllCheckers on a project.checkerPool without a request ID") + } + + // A request can only access one checker + if c, release := p.getRequestCheckerLocked(requestID); c != nil { + return []*checker.Checker{c}, release + } + + c, release := p.GetChecker(ctx) + return []*checker.Checker{c}, release +} + +func (p *checkerPool) getCheckerLocked(requestID string) (*checker.Checker, int) { + if checker, index := p.getImmediatelyAvailableChecker(); checker != nil { + p.inUse[checker] = true + if requestID != "" { + p.requestAssociations[requestID] = index + } + return checker, index + } + + if !p.isFullLocked() { + checker, index := p.createCheckerLocked() + p.inUse[checker] = true + if requestID != "" { + p.requestAssociations[requestID] = index + } + return checker, index + } + + checker, index := p.waitForAvailableChecker() + p.inUse[checker] = true + if requestID != "" { + p.requestAssociations[requestID] = index + } + return checker, index +} + +func (p *checkerPool) getRequestCheckerLocked(requestID string) (*checker.Checker, func()) { + if index, ok := p.requestAssociations[requestID]; ok { + checker := p.checkers[index] + if checker != nil { + if inUse := p.inUse[checker]; !inUse { + p.inUse[checker] = true + return checker, p.createRelease(requestID, index, checker) + } + // Checker is in use, but by the same request - assume it's the + // same goroutine or is managing its own synchronization + return checker, noop + } + } + return nil, noop +} + +func (p *checkerPool) getImmediatelyAvailableChecker() (*checker.Checker, int) { + for i, checker := range p.checkers { + if checker == nil { + continue + } + if inUse := p.inUse[checker]; !inUse { + return checker, i + } + } + + return nil, -1 +} + +func (p *checkerPool) waitForAvailableChecker() (*checker.Checker, int) { + p.log("checkerpool: Waiting for an available checker") + for { + p.cond.Wait() + checker, index := p.getImmediatelyAvailableChecker() + if checker != nil { + return checker, index + } + } +} + +func (p *checkerPool) createRelease(requestId string, index int, checker *checker.Checker) func() { + return func() { + p.mu.Lock() + defer p.mu.Unlock() + + delete(p.requestAssociations, requestId) + if checker.WasCanceled() { + // Canceled checkers must be disposed + p.log(fmt.Sprintf("checkerpool: Checker for request %s was canceled, disposing it", requestId)) + p.checkers[index] = nil + delete(p.inUse, checker) + } else { + p.inUse[checker] = false + } + p.cond.Signal() + } +} + +func (p *checkerPool) isFullLocked() bool { + for _, checker := range p.checkers { + if checker == nil { + return false + } + } + return true +} + +func (p *checkerPool) createCheckerLocked() (*checker.Checker, int) { + for i, existing := range p.checkers { + if existing == nil { + checker := checker.NewChecker(p.program) + p.checkers[i] = checker + return checker, i + } + } + panic("called createCheckerLocked when pool is full") +} + +func (p *checkerPool) isRequestCheckerInUse(requestID string) bool { + p.mu.Lock() + defer p.mu.Unlock() + + if index, ok := p.requestAssociations[requestID]; ok { + checker := p.checkers[index] + if checker != nil { + return p.inUse[checker] + } + } + return false +} + +func (p *checkerPool) size() int { + p.mu.Lock() + defer p.mu.Unlock() + size := 0 + for _, checker := range p.checkers { + if checker != nil { + size++ + } + } + return size +} + +func noop() {} diff --git a/internal/project/host.go b/internal/project/host.go index b9e9635e26..18bad0b808 100644 --- a/internal/project/host.go +++ b/internal/project/host.go @@ -1,6 +1,8 @@ package project import ( + "context" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" "github.com/microsoft/typescript-go/internal/vfs" ) @@ -8,9 +10,9 @@ import ( type WatcherHandle string type Client interface { - WatchFiles(watchers []*lsproto.FileSystemWatcher) (WatcherHandle, error) - UnwatchFiles(handle WatcherHandle) error - RefreshDiagnostics() error + WatchFiles(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (WatcherHandle, error) + UnwatchFiles(ctx context.Context, handle WatcherHandle) error + RefreshDiagnostics(ctx context.Context) error } type ServiceHost interface { diff --git a/internal/project/project.go b/internal/project/project.go index ff49305911..d5ff54a7e2 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -1,6 +1,7 @@ package project import ( + "context" "fmt" "maps" "slices" @@ -23,8 +24,6 @@ const hr = "-----------------------------------------------" var projectNamer = &namer{} -var _ ls.Host = (*Project)(nil) - type Kind int const ( @@ -34,6 +33,34 @@ const ( KindAuxiliary ) +type snapshot struct { + project *Project + positionEncoding lsproto.PositionEncodingKind + program *compiler.Program +} + +// GetLineMap implements ls.Host. +func (s *snapshot) GetLineMap(fileName string) *ls.LineMap { + file := s.program.GetSourceFile(fileName) + scriptInfo := s.project.host.GetScriptInfoByPath(file.Path()) + if file.Version == scriptInfo.Version() { + return scriptInfo.LineMap() + } + return ls.ComputeLineStarts(file.Text()) +} + +// GetPositionEncoding implements ls.Host. +func (s *snapshot) GetPositionEncoding() lsproto.PositionEncodingKind { + return s.positionEncoding +} + +// GetProgram implements ls.Host. +func (s *snapshot) GetProgram() *compiler.Program { + return s.program +} + +var _ ls.Host = (*snapshot)(nil) + type PendingReload int const ( @@ -57,13 +84,15 @@ type ProjectHost interface { Client() Client } +var _ compiler.CompilerHost = (*Project)(nil) + type Project struct { host ProjectHost - mu sync.Mutex name string kind Kind + dirtyStateMu sync.Mutex initialLoadPending bool dirty bool version int @@ -84,8 +113,9 @@ type Project struct { rootFileNames *collections.OrderedMap[tspath.Path, string] compilerOptions *core.CompilerOptions parsedCommandLine *tsoptions.ParsedCommandLine - languageService *ls.LanguageService + programMu sync.Mutex program *compiler.Program + checkerPool *checkerPool // Watchers rootFilesWatch *watchedFiles[[]string] @@ -134,42 +164,34 @@ func NewProject(name string, kind Kind, currentDirectory string, host ProjectHos return slices.Sorted(maps.Values(data)) }) } - project.languageService = ls.NewLanguageService(project) project.markAsDirty() return project } -// FS implements LanguageServiceHost. +// FS implements compiler.CompilerHost. func (p *Project) FS() vfs.FS { return p.host.FS() } -// DefaultLibraryPath implements LanguageServiceHost. +// DefaultLibraryPath implements compiler.CompilerHost. func (p *Project) DefaultLibraryPath() string { return p.host.DefaultLibraryPath() } -// GetCompilerOptions implements LanguageServiceHost. -func (p *Project) GetCompilerOptions() *core.CompilerOptions { - return p.compilerOptions -} - -// GetCurrentDirectory implements LanguageServiceHost. +// GetCurrentDirectory implements compiler.CompilerHost. func (p *Project) GetCurrentDirectory() string { return p.currentDirectory } -// GetProjectVersion implements LanguageServiceHost. -func (p *Project) GetProjectVersion() int { - return p.version -} - -// GetRootFileNames implements LanguageServiceHost. func (p *Project) GetRootFileNames() []string { return slices.Collect(p.rootFileNames.Values()) } -// GetSourceFile implements LanguageServiceHost. +func (p *Project) GetCompilerOptions() *core.CompilerOptions { + return p.compilerOptions +} + +// GetSourceFile implements compiler.CompilerHost. func (p *Project) GetSourceFile(fileName string, path tspath.Path, languageVersion core.ScriptTarget) *ast.SourceFile { scriptKind := p.getScriptKind(fileName) if scriptInfo := p.getOrCreateScriptInfoAndAttachToProject(fileName, scriptKind); scriptInfo != nil { @@ -181,42 +203,32 @@ func (p *Project) GetSourceFile(fileName string, path tspath.Path, languageVersi oldSourceFile = p.program.GetSourceFileByPath(scriptInfo.path) oldCompilerOptions = p.program.GetCompilerOptions() } - return p.host.DocumentRegistry().AcquireDocument(scriptInfo, p.GetCompilerOptions(), oldSourceFile, oldCompilerOptions) + return p.host.DocumentRegistry().AcquireDocument(scriptInfo, p.compilerOptions, oldSourceFile, oldCompilerOptions) } return nil } -// GetProgram implements LanguageServiceHost. Updates the program if needed. +// Updates the program if needed. func (p *Project) GetProgram() *compiler.Program { p.updateIfDirty() return p.program } -// NewLine implements LanguageServiceHost. +// NewLine implements compiler.CompilerHost. func (p *Project) NewLine() string { return p.host.NewLine() } -// Trace implements LanguageServiceHost. +// Trace implements compiler.CompilerHost. func (p *Project) Trace(msg string) { p.log(msg) } -// GetDefaultLibraryPath implements ls.Host. +// GetDefaultLibraryPath implements compiler.CompilerHost. func (p *Project) GetDefaultLibraryPath() string { return p.host.DefaultLibraryPath() } -// GetScriptInfo implements ls.Host. -func (p *Project) GetScriptInfo(fileName string) ls.ScriptInfo { - return p.host.GetScriptInfoByPath(p.toPath(fileName)) -} - -// GetPositionEncoding implements ls.Host. -func (p *Project) GetPositionEncoding() lsproto.PositionEncodingKind { - return p.host.PositionEncoding() -} - func (p *Project) Name() string { return p.name } @@ -233,8 +245,24 @@ func (p *Project) CurrentProgram() *compiler.Program { return p.program } -func (p *Project) LanguageService() *ls.LanguageService { - return p.languageService +func (p *Project) GetLanguageServiceForRequest(ctx context.Context) (*ls.LanguageService, func()) { + if core.GetRequestID(ctx) == "" { + panic("context must already have a request ID") + } + program := p.GetProgram() + checkerPool := p.checkerPool + snapshot := &snapshot{ + project: p, + positionEncoding: p.host.PositionEncoding(), + program: program, + } + languageService := ls.NewLanguageService(ctx, snapshot) + cleanup := func() { + if checkerPool.isRequestCheckerInUse(core.GetRequestID(ctx)) { + panic(fmt.Errorf("checker for request ID %s not returned to pool at end of request", core.GetRequestID(ctx))) + } + } + return languageService, cleanup } func (p *Project) getRootFileWatchGlobs() []string { @@ -275,7 +303,7 @@ func (p *Project) getModuleResolutionWatchGlobs() (failedLookups map[tspath.Path return failedLookups, affectingLocaions } -func (p *Project) updateWatchers() { +func (p *Project) updateWatchers(ctx context.Context) { client := p.host.Client() if !p.host.IsWatchEnabled() || client == nil { return @@ -285,20 +313,20 @@ func (p *Project) updateWatchers() { failedLookupGlobs, affectingLocationGlobs := p.getModuleResolutionWatchGlobs() if rootFileGlobs != nil { - if updated, err := p.rootFilesWatch.update(rootFileGlobs); err != nil { + if updated, err := p.rootFilesWatch.update(ctx, rootFileGlobs); err != nil { p.log(fmt.Sprintf("Failed to update root file watch: %v", err)) } else if updated { p.log("Root file watches updated:\n" + formatFileList(rootFileGlobs, "\t", hr)) } } - if updated, err := p.failedLookupsWatch.update(failedLookupGlobs); err != nil { + if updated, err := p.failedLookupsWatch.update(ctx, failedLookupGlobs); err != nil { p.log(fmt.Sprintf("Failed to update failed lookup watch: %v", err)) } else if updated { p.log("Failed lookup watches updated:\n" + formatFileList(p.failedLookupsWatch.globs, "\t", hr)) } - if updated, err := p.affectingLocationsWatch.update(affectingLocationGlobs); err != nil { + if updated, err := p.affectingLocationsWatch.update(ctx, affectingLocationGlobs); err != nil { p.log(fmt.Sprintf("Failed to update affecting location watch: %v", err)) } else if updated { p.log("Affecting location watches updated:\n" + formatFileList(p.affectingLocationsWatch.globs, "\t", hr)) @@ -348,8 +376,8 @@ func (p *Project) markFileAsDirty(path tspath.Path) { } func (p *Project) markAsDirty() { - p.mu.Lock() - defer p.mu.Unlock() + p.dirtyStateMu.Lock() + defer p.dirtyStateMu.Unlock() if !p.dirty { p.dirty = true p.version++ @@ -363,8 +391,8 @@ func (p *Project) updateIfDirty() bool { } func (p *Project) onFileAddedOrRemoved(isSymlink bool) { - p.mu.Lock() - defer p.mu.Unlock() + p.dirtyStateMu.Lock() + defer p.dirtyStateMu.Unlock() p.hasAddedOrRemovedFiles = true if isSymlink { p.hasAddedOrRemovedSymlinks = true @@ -376,7 +404,12 @@ func (p *Project) onFileAddedOrRemoved(isSymlink bool) { // opposite of the return value in Strada, which was frequently inverted, // as in `updateProjectIfDirty()`. func (p *Project) updateGraph() bool { - // !!! + p.programMu.Lock() + defer p.programMu.Unlock() + if !p.dirty { + return false + } + p.log("Starting updateGraph: Project: " + p.name) oldProgram := p.program hasAddedOrRemovedFiles := p.hasAddedOrRemovedFiles @@ -414,18 +447,27 @@ func (p *Project) updateGraph() bool { } } - p.updateWatchers() + // TODO: this is currently always synchronously called by some kind of updating request, + // but in Strada we throttle, so at least sometimes this should be considered top-level? + p.updateWatchers(context.TODO()) return true } func (p *Project) updateProgram() { rootFileNames := p.GetRootFileNames() - compilerOptions := p.GetCompilerOptions() + compilerOptions := p.compilerOptions + if p.checkerPool != nil { + p.logf("Program %d used %d checker(s)", p.version, p.checkerPool.size()) + } p.program = compiler.NewProgram(compiler.ProgramOptions{ RootFiles: rootFileNames, Host: p, Options: compilerOptions, + CreateCheckerPool: func(program *compiler.Program) compiler.CheckerPool { + p.checkerPool = newCheckerPool(4, program, p.log) + return p.checkerPool + }, }) p.program.BindSourceFiles() diff --git a/internal/project/scriptinfo.go b/internal/project/scriptinfo.go index dfa18a67af..a33dcbd297 100644 --- a/internal/project/scriptinfo.go +++ b/internal/project/scriptinfo.go @@ -9,7 +9,7 @@ import ( "github.com/microsoft/typescript-go/internal/vfs" ) -var _ ls.ScriptInfo = (*ScriptInfo)(nil) +var _ ls.Script = (*ScriptInfo)(nil) type ScriptInfo struct { fileName string diff --git a/internal/project/service.go b/internal/project/service.go index d97cbf1bfe..2ec6068710 100644 --- a/internal/project/service.go +++ b/internal/project/service.go @@ -1,6 +1,8 @@ package project import ( + "context" + "errors" "fmt" "strings" "sync" @@ -85,8 +87,8 @@ func NewService(host ServiceHost, options ServiceOptions) *Service { realpathToScriptInfos: make(map[tspath.Path]map[*ScriptInfo]struct{}), } - service.converters = ls.NewConverters(options.PositionEncoding, func(fileName string) ls.ScriptInfo { - return service.GetScriptInfo(fileName) + service.converters = ls.NewConverters(options.PositionEncoding, func(fileName string) *ls.LineMap { + return service.GetScriptInfo(fileName).LineMap() }) return service @@ -177,13 +179,30 @@ func (s *Service) OpenFile(fileName string, fileContent string, scriptKind core. s.printProjects() } -func (s *Service) ChangeFile(fileName string, changes []ls.TextChange) { +func (s *Service) ChangeFile(document lsproto.VersionedTextDocumentIdentifier, changes []lsproto.TextDocumentContentChangeEvent) error { + fileName := ls.DocumentURIToFileName(document.Uri) path := s.toPath(fileName) - info := s.GetScriptInfoByPath(path) - if info == nil { - panic("scriptInfo not found") + scriptInfo := s.GetScriptInfoByPath(path) + if scriptInfo == nil { + return fmt.Errorf("file %s not found", fileName) + } + + textChanges := make([]ls.TextChange, len(changes)) + for i, change := range changes { + if partialChange := change.TextDocumentContentChangePartial; partialChange != nil { + textChanges[i] = s.converters.FromLSPTextChange(scriptInfo, partialChange) + } else if wholeChange := change.TextDocumentContentChangeWholeDocument; wholeChange != nil { + textChanges[i] = ls.TextChange{ + TextRange: core.NewTextRange(0, len(scriptInfo.Text())), + NewText: wholeChange.Text, + } + } else { + return errors.New("invalid change type") + } } - s.applyChangesToFile(info, changes) + + s.applyChangesToFile(scriptInfo, textChanges) + return nil } func (s *Service) CloseFile(fileName string) { @@ -208,6 +227,11 @@ func (s *Service) MarkFileSaved(fileName string, text string) { } } +func (s *Service) EnsureDefaultProjectForURI(url lsproto.DocumentUri) *Project { + _, project := s.EnsureDefaultProjectForFile(ls.DocumentURIToFileName(url)) + return project +} + func (s *Service) EnsureDefaultProjectForFile(fileName string) (*ScriptInfo, *Project) { path := s.toPath(fileName) if info := s.GetScriptInfoByPath(path); info != nil && !info.isOrphan() { @@ -233,7 +257,7 @@ func (s *Service) SourceFileCount() int { return s.documentRegistry.size() } -func (s *Service) OnWatchedFilesChanged(changes []*lsproto.FileEvent) error { +func (s *Service) OnWatchedFilesChanged(ctx context.Context, changes []*lsproto.FileEvent) error { for _, change := range changes { fileName := ls.DocumentURIToFileName(change.Uri) path := s.toPath(fileName) @@ -264,7 +288,7 @@ func (s *Service) OnWatchedFilesChanged(changes []*lsproto.FileEvent) error { client := s.host.Client() if client != nil { - return client.RefreshDiagnostics() + return client.RefreshDiagnostics(ctx) } return nil diff --git a/internal/project/service_test.go b/internal/project/service_test.go index 3658ea7ebb..748b5189af 100644 --- a/internal/project/service_test.go +++ b/internal/project/service_test.go @@ -7,7 +7,6 @@ import ( "github.com/microsoft/typescript-go/internal/bundled" "github.com/microsoft/typescript-go/internal/core" - "github.com/microsoft/typescript-go/internal/ls" "github.com/microsoft/typescript-go/internal/lsp/lsproto" "github.com/microsoft/typescript-go/internal/project" "github.com/microsoft/typescript-go/internal/testutil/projecttestutil" @@ -94,7 +93,32 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/x.ts", defaultFiles["/home/projects/TS/p1/src/x.ts"], core.ScriptKindTS, "") info, proj := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/x.ts") programBefore := proj.GetProgram() - service.ChangeFile("/home/projects/TS/p1/src/x.ts", []ls.TextChange{{TextRange: core.NewTextRange(17, 18), NewText: "2"}}) + err := service.ChangeFile( + lsproto.VersionedTextDocumentIdentifier{ + TextDocumentIdentifier: lsproto.TextDocumentIdentifier{ + Uri: "file:///home/projects/TS/p1/src/x.ts", + }, + Version: 1, + }, + []lsproto.TextDocumentContentChangeEvent{ + lsproto.TextDocumentContentChangePartialOrTextDocumentContentChangeWholeDocument{ + TextDocumentContentChangePartial: ptrTo(lsproto.TextDocumentContentChangePartial{ + Range: lsproto.Range{ + Start: lsproto.Position{ + Line: 0, + Character: 17, + }, + End: lsproto.Position{ + Line: 0, + Character: 18, + }, + }, + Text: "2", + }), + }, + }, + ) + assert.NilError(t, err) assert.Equal(t, info.Text(), "export const x = 2;") assert.Equal(t, proj.CurrentProgram(), programBefore) assert.Equal(t, programBefore.GetSourceFile("/home/projects/TS/p1/src/x.ts").Text(), "export const x = 1;") @@ -108,7 +132,32 @@ func TestService(t *testing.T) { _, proj := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/x.ts") programBefore := proj.GetProgram() indexFileBefore := programBefore.GetSourceFile("/home/projects/TS/p1/src/index.ts") - service.ChangeFile("/home/projects/TS/p1/src/x.ts", nil) + err := service.ChangeFile( + lsproto.VersionedTextDocumentIdentifier{ + TextDocumentIdentifier: lsproto.TextDocumentIdentifier{ + Uri: "file:///home/projects/TS/p1/src/x.ts", + }, + Version: 1, + }, + []lsproto.TextDocumentContentChangeEvent{ + lsproto.TextDocumentContentChangePartialOrTextDocumentContentChangeWholeDocument{ + TextDocumentContentChangePartial: ptrTo(lsproto.TextDocumentContentChangePartial{ + Range: lsproto.Range{ + Start: lsproto.Position{ + Line: 0, + Character: 0, + }, + End: lsproto.Position{ + Line: 0, + Character: 0, + }, + }, + Text: ";", + }), + }, + }, + ) + assert.NilError(t, err) assert.Equal(t, proj.GetProgram().GetSourceFile("/home/projects/TS/p1/src/index.ts"), indexFileBefore) }) @@ -120,7 +169,32 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/index.ts", files["/home/projects/TS/p1/src/index.ts"], core.ScriptKindTS, "") assert.Check(t, service.GetScriptInfo("/home/projects/TS/p1/y.ts") == nil) - service.ChangeFile("/home/projects/TS/p1/src/index.ts", []ls.TextChange{{TextRange: core.NewTextRange(0, 0), NewText: `import { y } from "../y";\n`}}) + err := service.ChangeFile( + lsproto.VersionedTextDocumentIdentifier{ + TextDocumentIdentifier: lsproto.TextDocumentIdentifier{ + Uri: "file:///home/projects/TS/p1/src/index.ts", + }, + Version: 1, + }, + []lsproto.TextDocumentContentChangeEvent{ + lsproto.TextDocumentContentChangePartialOrTextDocumentContentChangeWholeDocument{ + TextDocumentContentChangePartial: ptrTo(lsproto.TextDocumentContentChangePartial{ + Range: lsproto.Range{ + Start: lsproto.Position{ + Line: 0, + Character: 0, + }, + End: lsproto.Position{ + Line: 0, + Character: 0, + }, + }, + Text: `import { y } from "../y";\n`, + }), + }, + }, + ) + assert.NilError(t, err) service.EnsureDefaultProjectForFile("/home/projects/TS/p1/y.ts") }) }) @@ -245,7 +319,7 @@ func TestService(t *testing.T) { files := maps.Clone(defaultFiles) files["/home/projects/TS/p1/src/x.ts"] = `export const x = 2;` host.ReplaceFS(files) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeChanged, Uri: "file:///home/projects/TS/p1/src/x.ts", @@ -265,7 +339,7 @@ func TestService(t *testing.T) { files := maps.Clone(defaultFiles) files["/home/projects/TS/p1/src/x.ts"] = `export const x = 2;` host.ReplaceFS(files) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeChanged, Uri: "file:///home/projects/TS/p1/src/x.ts", @@ -294,7 +368,7 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/index.ts", files["/home/projects/TS/p1/src/index.ts"], core.ScriptKindTS, "") _, project := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/index.ts") program := project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) filesCopy := maps.Clone(files) filesCopy["/home/projects/TS/p1/tsconfig.json"] = `{ @@ -304,7 +378,7 @@ func TestService(t *testing.T) { } }` host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeChanged, Uri: "file:///home/projects/TS/p1/tsconfig.json", @@ -312,7 +386,7 @@ func TestService(t *testing.T) { })) program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) }) t.Run("delete explicitly included file", func(t *testing.T) { @@ -331,12 +405,12 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/index.ts", files["/home/projects/TS/p1/src/index.ts"], core.ScriptKindTS, "") _, project := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/index.ts") program := project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) filesCopy := maps.Clone(files) delete(filesCopy, "/home/projects/TS/p1/src/x.ts") host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeDeleted, Uri: "file:///home/projects/TS/p1/src/x.ts", @@ -344,7 +418,7 @@ func TestService(t *testing.T) { })) program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) assert.Check(t, program.GetSourceFile("/home/projects/TS/p1/src/x.ts") == nil) }) @@ -364,12 +438,12 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/x.ts", files["/home/projects/TS/p1/src/x.ts"], core.ScriptKindTS, "") _, project := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/x.ts") program := project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/x.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/x.ts"))), 0) filesCopy := maps.Clone(files) delete(filesCopy, "/home/projects/TS/p1/src/index.ts") host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeDeleted, Uri: "file:///home/projects/TS/p1/src/index.ts", @@ -377,7 +451,7 @@ func TestService(t *testing.T) { })) program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/x.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/x.ts"))), 1) }) t.Run("create explicitly included file", func(t *testing.T) { @@ -397,7 +471,7 @@ func TestService(t *testing.T) { program := project.GetProgram() // Initially should have an error because y.ts is missing - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) // Missing location should be watched assert.DeepEqual(t, host.ClientMock.WatchFilesCalls()[0].Watchers, []*lsproto.FileSystemWatcher{ @@ -425,7 +499,7 @@ func TestService(t *testing.T) { filesCopy := maps.Clone(files) filesCopy["/home/projects/TS/p1/src/y.ts"] = `export const y = 1;` host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeCreated, Uri: "file:///home/projects/TS/p1/src/y.ts", @@ -434,7 +508,7 @@ func TestService(t *testing.T) { // Error should be resolved program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) assert.Check(t, program.GetSourceFile("/home/projects/TS/p1/src/y.ts") != nil) }) @@ -455,7 +529,7 @@ func TestService(t *testing.T) { program := project.GetProgram() // Initially should have an error because z.ts is missing - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) // Missing location should be watched assert.Check(t, slices.ContainsFunc(host.ClientMock.WatchFilesCalls()[1].Watchers, func(w *lsproto.FileSystemWatcher) bool { @@ -466,7 +540,7 @@ func TestService(t *testing.T) { filesCopy := maps.Clone(files) filesCopy["/home/projects/TS/p1/src/z.ts"] = `export const z = 1;` host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeCreated, Uri: "file:///home/projects/TS/p1/src/z.ts", @@ -475,7 +549,7 @@ func TestService(t *testing.T) { // Error should be resolved and the new file should be included in the program program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) assert.Check(t, program.GetSourceFile("/home/projects/TS/p1/src/z.ts") != nil) }) @@ -496,13 +570,13 @@ func TestService(t *testing.T) { program := project.GetProgram() // Initially should have an error because declaration for 'a' is missing - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) // Add a new file through wildcard watch filesCopy := maps.Clone(files) filesCopy["/home/projects/TS/p1/src/a.ts"] = `const a = 1;` host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeCreated, Uri: "file:///home/projects/TS/p1/src/a.ts", @@ -511,7 +585,7 @@ func TestService(t *testing.T) { // Error should be resolved and the new file should be included in the program program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) assert.Check(t, program.GetSourceFile("/home/projects/TS/p1/src/a.ts") != nil) }) }) diff --git a/internal/project/watch.go b/internal/project/watch.go index 18db3a8c91..adf32a4dbb 100644 --- a/internal/project/watch.go +++ b/internal/project/watch.go @@ -1,6 +1,7 @@ package project import ( + "context" "slices" "github.com/microsoft/typescript-go/internal/lsp/lsproto" @@ -29,7 +30,7 @@ func newWatchedFiles[T any](client Client, watchKind lsproto.WatchKind, getGlobs } } -func (w *watchedFiles[T]) update(newData T) (updated bool, err error) { +func (w *watchedFiles[T]) update(ctx context.Context, newData T) (updated bool, err error) { newGlobs := w.getGlobs(newData) w.data = newData if slices.Equal(w.globs, newGlobs) { @@ -38,7 +39,7 @@ func (w *watchedFiles[T]) update(newData T) (updated bool, err error) { w.globs = newGlobs if w.watcherID != "" { - if err = w.client.UnwatchFiles(w.watcherID); err != nil { + if err = w.client.UnwatchFiles(ctx, w.watcherID); err != nil { return false, err } } @@ -52,7 +53,7 @@ func (w *watchedFiles[T]) update(newData T) (updated bool, err error) { Kind: &w.watchKind, }) } - watcherID, err := w.client.WatchFiles(watchers) + watcherID, err := w.client.WatchFiles(ctx, watchers) if err != nil { return false, err } diff --git a/internal/testrunner/compiler_runner.go b/internal/testrunner/compiler_runner.go index 5f6c48091b..f09a987296 100644 --- a/internal/testrunner/compiler_runner.go +++ b/internal/testrunner/compiler_runner.go @@ -455,7 +455,9 @@ func createHarnessTestFile(unit *testUnit, currentDirectory string) *harnessutil func (c *compilerTest) verifyUnionOrdering(t *testing.T) { t.Run("union ordering", func(t *testing.T) { - for _, c := range c.result.Program.GetTypeCheckers() { + checkers, done := c.result.Program.GetTypeCheckers(t.Context()) + defer done() + for _, c := range checkers { for union := range c.UnionTypes() { types := union.Types() diff --git a/internal/testutil/harnessutil/harnessutil.go b/internal/testutil/harnessutil/harnessutil.go index 71526477f7..3f44747893 100644 --- a/internal/testutil/harnessutil/harnessutil.go +++ b/internal/testutil/harnessutil/harnessutil.go @@ -537,11 +537,12 @@ func compileFilesWithHost( // ...ts.filter(longerErrors!, p => !ts.some(shorterErrors, p2 => ts.compareDiagnostics(p, p2) === ts.Comparison.EqualTo)), // ), // ] : postErrors; + ctx := context.Background() program := createProgram(host, options, rootFiles) var diagnostics []*ast.Diagnostic - diagnostics = append(diagnostics, program.GetSyntacticDiagnostics(context.Background(), nil)...) - diagnostics = append(diagnostics, program.GetSemanticDiagnostics(context.Background(), nil)...) - diagnostics = append(diagnostics, program.GetGlobalDiagnostics()...) + diagnostics = append(diagnostics, program.GetSyntacticDiagnostics(ctx, nil)...) + diagnostics = append(diagnostics, program.GetSemanticDiagnostics(ctx, nil)...) + diagnostics = append(diagnostics, program.GetGlobalDiagnostics(ctx)...) emitResult := program.Emit(compiler.EmitOptions{}) return newCompilationResult(options, program, emitResult, diagnostics, harnessOptions) diff --git a/internal/testutil/lstestutil/lstestutil.go b/internal/testutil/lstestutil/lstestutil.go index fc81641246..e4fe1d5f54 100644 --- a/internal/testutil/lstestutil/lstestutil.go +++ b/internal/testutil/lstestutil/lstestutil.go @@ -5,6 +5,8 @@ import ( "strings" "github.com/microsoft/typescript-go/internal/core" + "github.com/microsoft/typescript-go/internal/ls" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) type markerRange struct { @@ -15,9 +17,10 @@ type markerRange struct { } type Marker struct { - Filename string - Position int - Name string + Filename string + Position int + LSPosition lsproto.Position + Name string } type TestData struct { @@ -86,6 +89,18 @@ type TestFileInfo struct { // for FourSlashFile Content string } +// FileName implements ls.Script. +func (t *TestFileInfo) FileName() string { + return t.Filename +} + +// Text implements ls.Script. +func (t *TestFileInfo) Text() string { + return t.Content +} + +var _ ls.Script = (*TestFileInfo)(nil) + func parseFileContent(content string, filename string, markerMap map[string]*Marker, markers *[]*Marker) *TestFileInfo { // !!! chompLeadingSpace // !!! validate characters in markers @@ -98,7 +113,7 @@ func parseFileContent(content string, filename string, markerMap map[string]*Mar /// The total number of metacharacters removed from the file (so far) difference := 0 - /// Current position data + /// One-based current position data line := 1 column := 1 @@ -125,7 +140,7 @@ func parseFileContent(content string, filename string, markerMap map[string]*Mar position: (i - 1) - difference, sourcePosition: i - 1, sourceLine: line, - sourceColumn: column, + sourceColumn: column - 1, } } if previousCharacter == '*' && currentCharacter == '/' { @@ -157,10 +172,22 @@ func parseFileContent(content string, filename string, markerMap map[string]*Mar // Add the remaining text flush(-1) - return &TestFileInfo{ - Content: output, + // Set LS positions for markers + lineMap := ls.ComputeLineStarts(output) + converters := ls.NewConverters(lsproto.PositionEncodingKindUTF8, func(_ string) *ls.LineMap { + return lineMap + }) + + testFileInfo := &TestFileInfo{ Filename: filename, + Content: output, } + + for _, marker := range *markers { + marker.LSPosition = converters.PositionToLineAndCharacter(testFileInfo, core.TextPos(marker.Position)) + } + + return testFileInfo } func recordMarker( diff --git a/internal/testutil/projecttestutil/clientmock_generated.go b/internal/testutil/projecttestutil/clientmock_generated.go index e8bf1ab779..a19433ac8a 100644 --- a/internal/testutil/projecttestutil/clientmock_generated.go +++ b/internal/testutil/projecttestutil/clientmock_generated.go @@ -4,6 +4,7 @@ package projecttestutil import ( + "context" "sync" "github.com/microsoft/typescript-go/internal/lsp/lsproto" @@ -20,13 +21,13 @@ var _ project.Client = &ClientMock{} // // // make and configure a mocked project.Client // mockedClient := &ClientMock{ -// RefreshDiagnosticsFunc: func() error { +// RefreshDiagnosticsFunc: func(ctx context.Context) error { // panic("mock out the RefreshDiagnostics method") // }, -// UnwatchFilesFunc: func(handle project.WatcherHandle) error { +// UnwatchFilesFunc: func(ctx context.Context, handle project.WatcherHandle) error { // panic("mock out the UnwatchFiles method") // }, -// WatchFilesFunc: func(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { +// WatchFilesFunc: func(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { // panic("mock out the WatchFiles method") // }, // } @@ -37,26 +38,32 @@ var _ project.Client = &ClientMock{} // } type ClientMock struct { // RefreshDiagnosticsFunc mocks the RefreshDiagnostics method. - RefreshDiagnosticsFunc func() error + RefreshDiagnosticsFunc func(ctx context.Context) error // UnwatchFilesFunc mocks the UnwatchFiles method. - UnwatchFilesFunc func(handle project.WatcherHandle) error + UnwatchFilesFunc func(ctx context.Context, handle project.WatcherHandle) error // WatchFilesFunc mocks the WatchFiles method. - WatchFilesFunc func(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) + WatchFilesFunc func(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) // calls tracks calls to the methods. calls struct { // RefreshDiagnostics holds details about calls to the RefreshDiagnostics method. RefreshDiagnostics []struct { + // Ctx is the ctx argument value. + Ctx context.Context } // UnwatchFiles holds details about calls to the UnwatchFiles method. UnwatchFiles []struct { + // Ctx is the ctx argument value. + Ctx context.Context // Handle is the handle argument value. Handle project.WatcherHandle } // WatchFiles holds details about calls to the WatchFiles method. WatchFiles []struct { + // Ctx is the ctx argument value. + Ctx context.Context // Watchers is the watchers argument value. Watchers []*lsproto.FileSystemWatcher } @@ -67,9 +74,12 @@ type ClientMock struct { } // RefreshDiagnostics calls RefreshDiagnosticsFunc. -func (mock *ClientMock) RefreshDiagnostics() error { +func (mock *ClientMock) RefreshDiagnostics(ctx context.Context) error { callInfo := struct { - }{} + Ctx context.Context + }{ + Ctx: ctx, + } mock.lockRefreshDiagnostics.Lock() mock.calls.RefreshDiagnostics = append(mock.calls.RefreshDiagnostics, callInfo) mock.lockRefreshDiagnostics.Unlock() @@ -79,7 +89,7 @@ func (mock *ClientMock) RefreshDiagnostics() error { ) return errOut } - return mock.RefreshDiagnosticsFunc() + return mock.RefreshDiagnosticsFunc(ctx) } // RefreshDiagnosticsCalls gets all the calls that were made to RefreshDiagnostics. @@ -87,8 +97,10 @@ func (mock *ClientMock) RefreshDiagnostics() error { // // len(mockedClient.RefreshDiagnosticsCalls()) func (mock *ClientMock) RefreshDiagnosticsCalls() []struct { + Ctx context.Context } { var calls []struct { + Ctx context.Context } mock.lockRefreshDiagnostics.RLock() calls = mock.calls.RefreshDiagnostics @@ -97,10 +109,12 @@ func (mock *ClientMock) RefreshDiagnosticsCalls() []struct { } // UnwatchFiles calls UnwatchFilesFunc. -func (mock *ClientMock) UnwatchFiles(handle project.WatcherHandle) error { +func (mock *ClientMock) UnwatchFiles(ctx context.Context, handle project.WatcherHandle) error { callInfo := struct { + Ctx context.Context Handle project.WatcherHandle }{ + Ctx: ctx, Handle: handle, } mock.lockUnwatchFiles.Lock() @@ -112,7 +126,7 @@ func (mock *ClientMock) UnwatchFiles(handle project.WatcherHandle) error { ) return errOut } - return mock.UnwatchFilesFunc(handle) + return mock.UnwatchFilesFunc(ctx, handle) } // UnwatchFilesCalls gets all the calls that were made to UnwatchFiles. @@ -120,9 +134,11 @@ func (mock *ClientMock) UnwatchFiles(handle project.WatcherHandle) error { // // len(mockedClient.UnwatchFilesCalls()) func (mock *ClientMock) UnwatchFilesCalls() []struct { + Ctx context.Context Handle project.WatcherHandle } { var calls []struct { + Ctx context.Context Handle project.WatcherHandle } mock.lockUnwatchFiles.RLock() @@ -132,10 +148,12 @@ func (mock *ClientMock) UnwatchFilesCalls() []struct { } // WatchFiles calls WatchFilesFunc. -func (mock *ClientMock) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { +func (mock *ClientMock) WatchFiles(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { callInfo := struct { + Ctx context.Context Watchers []*lsproto.FileSystemWatcher }{ + Ctx: ctx, Watchers: watchers, } mock.lockWatchFiles.Lock() @@ -148,7 +166,7 @@ func (mock *ClientMock) WatchFiles(watchers []*lsproto.FileSystemWatcher) (proje ) return watcherHandleOut, errOut } - return mock.WatchFilesFunc(watchers) + return mock.WatchFilesFunc(ctx, watchers) } // WatchFilesCalls gets all the calls that were made to WatchFiles. @@ -156,9 +174,11 @@ func (mock *ClientMock) WatchFiles(watchers []*lsproto.FileSystemWatcher) (proje // // len(mockedClient.WatchFilesCalls()) func (mock *ClientMock) WatchFilesCalls() []struct { + Ctx context.Context Watchers []*lsproto.FileSystemWatcher } { var calls []struct { + Ctx context.Context Watchers []*lsproto.FileSystemWatcher } mock.lockWatchFiles.RLock() diff --git a/internal/testutil/projecttestutil/projecttestutil.go b/internal/testutil/projecttestutil/projecttestutil.go index ea70b913db..ed23751f49 100644 --- a/internal/testutil/projecttestutil/projecttestutil.go +++ b/internal/testutil/projecttestutil/projecttestutil.go @@ -1,12 +1,14 @@ package projecttestutil import ( + "context" "fmt" "io" "strings" "sync" "github.com/microsoft/typescript-go/internal/bundled" + "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/project" "github.com/microsoft/typescript-go/internal/vfs" "github.com/microsoft/typescript-go/internal/vfs/vfstest" @@ -80,3 +82,7 @@ func newProjectServiceHost(files map[string]string) *ProjectServiceHost { host.logger = project.NewLogger([]io.Writer{&host.output}, "", project.LogLevelVerbose) return host } + +func WithRequestID(ctx context.Context) context.Context { + return core.WithRequestID(ctx, "0") +} diff --git a/internal/testutil/tsbaseline/type_symbol_baseline.go b/internal/testutil/tsbaseline/type_symbol_baseline.go index ef4336490d..3cc5b44c9d 100644 --- a/internal/testutil/tsbaseline/type_symbol_baseline.go +++ b/internal/testutil/tsbaseline/type_symbol_baseline.go @@ -1,6 +1,7 @@ package tsbaseline import ( + "context" "fmt" "regexp" "slices" @@ -263,10 +264,10 @@ func newTypeWriterWalker(program *compiler.Program, hadErrorBaseline bool) *type } } -func (walker *typeWriterWalker) getTypeCheckerForCurrentFile() *checker.Checker { +func (walker *typeWriterWalker) getTypeCheckerForCurrentFile() (*checker.Checker, func()) { // If we don't use the right checker for the file, its contents won't be up to date // since the types/symbols baselines appear to depend on files having been checked. - return walker.program.GetTypeCheckerForFile(walker.currentSourceFile) + return walker.program.GetTypeCheckerForFile(context.Background(), walker.currentSourceFile) } type typeWriterResult struct { @@ -332,7 +333,8 @@ func (walker *typeWriterWalker) writeTypeOrSymbol(node *ast.Node, isSymbolWalk b actualPos := scanner.SkipTrivia(walker.currentSourceFile.Text(), node.Pos()) line, _ := scanner.GetLineAndCharacterOfPosition(walker.currentSourceFile, actualPos) sourceText := scanner.GetSourceTextOfNodeFromSourceFile(walker.currentSourceFile, node, false /*includeTrivia*/) - fileChecker := walker.getTypeCheckerForCurrentFile() + fileChecker, done := walker.getTypeCheckerForCurrentFile() + defer done() if !isSymbolWalk { // Don't try to get the type of something that's already a type.