From af44b2618d8972772fa365e5d0b523d8ee7656ae Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Tue, 29 Apr 2025 12:29:38 -0700 Subject: [PATCH 01/12] Asyncify --- internal/lsp/lsproto/jsonrpc.go | 115 ++++++++++++- internal/lsp/server.go | 294 +++++++++++++++++++++----------- 2 files changed, 306 insertions(+), 103 deletions(-) diff --git a/internal/lsp/lsproto/jsonrpc.go b/internal/lsp/lsproto/jsonrpc.go index e909e9442e..ca54805215 100644 --- a/internal/lsp/lsproto/jsonrpc.go +++ b/internal/lsp/lsproto/jsonrpc.go @@ -28,6 +28,13 @@ type ID struct { int int32 } +func NewID(rawValue IntegerOrString) *ID { + if rawValue.String != nil { + return &ID{str: *rawValue.String} + } + return &ID{int: *rawValue.Integer} +} + func NewIDString(str string) *ID { return &ID{str: str} } @@ -61,13 +68,92 @@ func (id *ID) MustInt() int32 { return id.int } -// TODO(jakebailey): NotificationMessage? Use RequestMessage without ID? +type MessageKind int + +const ( + MessageKindNotification MessageKind = iota + MessageKindRequest + MessageKindResponse +) + +type Message struct { + Kind MessageKind + msg any +} + +func (m *Message) AsRequest() *RequestMessage { + return m.msg.(*RequestMessage) +} + +func (m *Message) AsResponse() *ResponseMessage { + return m.msg.(*ResponseMessage) +} + +func (m *Message) UnmarshalJSON(data []byte) error { + var raw struct { + JSONRPC JSONRPCVersion `json:"jsonrpc"` + Method Method `json:"method"` + ID *ID `json:"id,omitempty"` + Params json.RawMessage `json:"params"` + Result any `json:"result,omitempty"` + Error *ResponseError `json:"error,omitempty"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("%w: %w", ErrInvalidRequest, err) + } + if raw.ID != nil && raw.Method == "" { + m.Kind = MessageKindResponse + m.msg = &ResponseMessage{ + JSONRPC: raw.JSONRPC, + ID: raw.ID, + Result: raw.Result, + Error: raw.Error, + } + return nil + } + + var params any + var err error + if len(raw.Params) > 0 { + params, err = unmarshalParams(raw.Method, raw.Params) + if err != nil { + return fmt.Errorf("%w: %w", ErrInvalidRequest, err) + } + } + + if raw.ID == nil { + m.Kind = MessageKindNotification + } else { + m.Kind = MessageKindRequest + } + + m.msg = &RequestMessage{ + JSONRPC: raw.JSONRPC, + ID: raw.ID, + Method: raw.Method, + Params: params, + } + + return nil +} + +func (m *Message) MarshalJSON() ([]byte, error) { + return json.Marshal(m.msg) +} + +func NewNotificationMessage(method Method, params any) *RequestMessage { + return &RequestMessage{ + JSONRPC: JSONRPCVersion{}, + Method: method, + Params: params, + } +} type RequestMessage struct { JSONRPC JSONRPCVersion `json:"jsonrpc"` ID *ID `json:"id,omitempty"` Method Method `json:"method"` - Params any `json:"params"` + Params any `json:"params,omitempty"` } func NewRequestMessage(method Method, id *ID, params any) *RequestMessage { @@ -78,6 +164,13 @@ func NewRequestMessage(method Method, id *ID, params any) *RequestMessage { } } +func (r *RequestMessage) Message() *Message { + return &Message{ + Kind: MessageKindRequest, + msg: r, + } +} + func (r *RequestMessage) UnmarshalJSON(data []byte) error { var raw struct { JSONRPC JSONRPCVersion `json:"jsonrpc"` @@ -108,8 +201,26 @@ type ResponseMessage struct { Error *ResponseError `json:"error,omitempty"` } +func (r *ResponseMessage) Message() *Message { + return &Message{ + Kind: MessageKindResponse, + msg: r, + } +} + type ResponseError struct { Code int32 `json:"code"` Message string `json:"message"` Data any `json:"data,omitempty"` } + +func (r *ResponseError) String() string { + if r == nil { + return "" + } + data, err := json.Marshal(r.Data) + if err != nil { + return fmt.Sprintf("[%d]: %s\n%v", r.Code, r.Message, data) + } + return fmt.Sprintf("[%d]: %s", r.Code, r.Message) +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 2931f40178..71e44545fe 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -1,6 +1,7 @@ package lsp import ( + "context" "encoding/json" "errors" "fmt" @@ -8,7 +9,7 @@ import ( "runtime/debug" "slices" "strings" - "time" + "sync" "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/ls" @@ -33,13 +34,18 @@ func NewServer(opts *ServerOptions) *Server { panic("Cwd is required") } return &Server{ - r: lsproto.NewBaseReader(opts.In), - w: lsproto.NewBaseWriter(opts.Out), - stderr: opts.Err, - cwd: opts.Cwd, - newLine: opts.NewLine, - fs: opts.FS, - defaultLibraryPath: opts.DefaultLibraryPath, + r: lsproto.NewBaseReader(opts.In), + w: lsproto.NewBaseWriter(opts.Out), + stderr: opts.Err, + fatalErrChan: make(chan error, 1), + requestQueue: make(chan *lsproto.RequestMessage, 100), + outgoingQueue: make(chan *lsproto.Message, 100), + pendingClientRequests: make(map[lsproto.ID]pendingClientRequest), + pendingServerRequests: make(map[lsproto.ID]chan *lsproto.ResponseMessage), + cwd: opts.Cwd, + newLine: opts.NewLine, + fs: opts.FS, + defaultLibraryPath: opts.DefaultLibraryPath, } } @@ -48,15 +54,25 @@ var ( _ project.Client = (*Server)(nil) ) +type pendingClientRequest struct { + req *lsproto.RequestMessage + cancel context.CancelFunc +} + type Server struct { r *lsproto.BaseReader w *lsproto.BaseWriter stderr io.Writer - clientSeq int32 - requestMethod string - requestTime time.Time + clientSeq int32 + fatalErrChan chan error + requestQueue chan *lsproto.RequestMessage + outgoingQueue chan *lsproto.Message + pendingClientRequests map[lsproto.ID]pendingClientRequest + pendingClientRequestsMu sync.Mutex + pendingServerRequests map[lsproto.ID]chan *lsproto.ResponseMessage + pendingServerRequestsMu sync.Mutex cwd string newLine core.NewLineKind @@ -110,7 +126,7 @@ func (s *Server) Client() project.Client { // WatchFiles implements project.Client. func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { watcherId := fmt.Sprintf("watcher-%d", s.watcherID) - if err := s.sendRequest(lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ + respChan, err := s.sendRequest(lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ Registrations: []*lsproto.Registration{ { Id: watcherId, @@ -120,10 +136,18 @@ func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.Watc })), }, }, - }); err != nil { + }) + + if err != nil { return "", fmt.Errorf("failed to register file watcher: %w", err) } + // TODO: timeout? + resp := <-respChan + if resp.Error != nil { + return "", fmt.Errorf("failed to register file watcher: %s", resp.Error.String()) + } + handle := project.WatcherHandle(watcherId) s.watchers.Add(handle) s.watcherID++ @@ -133,16 +157,24 @@ func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.Watc // UnwatchFiles implements project.Client. func (s *Server) UnwatchFiles(handle project.WatcherHandle) error { if s.watchers.Has(handle) { - if err := s.sendRequest(lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{ + respChan, err := s.sendRequest(lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{ Unregisterations: []*lsproto.Unregistration{ { Id: string(handle), Method: string(lsproto.MethodWorkspaceDidChangeWatchedFiles), }, }, - }); err != nil { + }) + + if err != nil { return fmt.Errorf("failed to unregister file watcher: %w", err) } + + resp := <-respChan + if resp.Error != nil { + return fmt.Errorf("failed to unregister file watcher: %s", resp.Error.String()) + } + s.watchers.Delete(handle) return nil } @@ -161,70 +193,124 @@ func (s *Server) RefreshDiagnostics() error { } func (s *Server) Run() error { + go s.dispatchLoop() + go s.writeLoop() + return s.readLoop() +} + +func (s *Server) readLoop() error { for { - req, err := s.read() + msg, err := s.read() if err != nil { if errors.Is(err, lsproto.ErrInvalidRequest) { - if err := s.sendError(nil, err); err != nil { - return err - } + s.sendError(nil, err) continue } return err } - // TODO: handle response messages - if req == nil { - continue - } - - if s.initializeParams == nil { + if s.initializeParams == nil && msg.Kind == lsproto.MessageKindRequest { + req := msg.AsRequest() if req.Method == lsproto.MethodInitialize { - if err := s.handleInitialize(req); err != nil { - return err - } + s.handleInitialize(req) } else { - if err := s.sendError(req.ID, lsproto.ErrServerNotInitialized); err != nil { - return err - } + s.sendError(req.ID, lsproto.ErrServerNotInitialized) } continue } - if err := s.handleMessage(req); err != nil { - return err + if msg.Kind == lsproto.MessageKindResponse { + resp := msg.AsResponse() + s.pendingServerRequestsMu.Lock() + if respChan, ok := s.pendingServerRequests[*resp.ID]; ok { + respChan <- resp + close(respChan) + delete(s.pendingServerRequests, *resp.ID) + } + s.pendingServerRequestsMu.Unlock() + } else { + req := msg.AsRequest() + if req.Method == lsproto.MethodCancelRequest { + go s.cancelRequest(req.Params.(*lsproto.CancelParams).Id) + } else { + s.requestQueue <- req + } } } } -func (s *Server) read() (*lsproto.RequestMessage, error) { +func (s *Server) cancelRequest(rawID lsproto.IntegerOrString) { + id := lsproto.NewID(rawID) + s.pendingClientRequestsMu.Lock() + defer s.pendingClientRequestsMu.Unlock() + if pendingReq, ok := s.pendingClientRequests[*id]; ok { + pendingReq.cancel() + delete(s.pendingClientRequests, *id) + } +} + +func (s *Server) read() (*lsproto.Message, error) { data, err := s.r.Read() if err != nil { return nil, err } - req := &lsproto.RequestMessage{} + req := &lsproto.Message{} if err := json.Unmarshal(data, req); err != nil { - res := &lsproto.ResponseMessage{} - if err = json.Unmarshal(data, res); err == nil { - // !!! TODO: handle response - return nil, nil - } return nil, fmt.Errorf("%w: %w", lsproto.ErrInvalidRequest, err) } return req, nil } -func (s *Server) sendRequest(method lsproto.Method, params any) error { +func (s *Server) dispatchLoop() { + for req := range s.requestQueue { + ctx := context.Background() + + if req.ID != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + s.pendingClientRequestsMu.Lock() + s.pendingClientRequests[*req.ID] = pendingClientRequest{ + req: req, + cancel: cancel, + } + s.pendingClientRequestsMu.Unlock() + } + + if err := s.handleRequestOrNotification(ctx, req); err != nil { + s.fatalErrChan <- err + return + } + } +} + +func (s *Server) writeLoop() { + for msg := range s.outgoingQueue { + data, err := json.Marshal(msg) + if err != nil { + s.fatalErrChan <- fmt.Errorf("failed to marshal message: %w", err) + continue + } + if err := s.w.Write(data); err != nil { + s.fatalErrChan <- fmt.Errorf("failed to write message: %w", err) + continue + } + } +} + +func (s *Server) sendRequest(method lsproto.Method, params any) (<-chan *lsproto.ResponseMessage, error) { s.clientSeq++ id := lsproto.NewIDString(fmt.Sprintf("ts%d", s.clientSeq)) req := lsproto.NewRequestMessage(method, id, params) - data, err := json.Marshal(req) - if err != nil { - return err - } - return s.w.Write(data) + + responseChan := make(chan *lsproto.ResponseMessage, 1) + s.pendingServerRequestsMu.Lock() + s.pendingServerRequests[*id] = responseChan + s.pendingServerRequestsMu.Unlock() + + s.outgoingQueue <- req.Message() + return responseChan, nil } func (s *Server) sendNotification(method lsproto.Method, params any) error { @@ -236,20 +322,20 @@ func (s *Server) sendNotification(method lsproto.Method, params any) error { return s.w.Write(data) } -func (s *Server) sendResult(id *lsproto.ID, result any) error { - return s.sendResponse(&lsproto.ResponseMessage{ +func (s *Server) sendResult(id *lsproto.ID, result any) { + s.sendResponse(&lsproto.ResponseMessage{ ID: id, Result: result, }) } -func (s *Server) sendError(id *lsproto.ID, err error) error { +func (s *Server) sendError(id *lsproto.ID, err error) { code := lsproto.ErrInternalError.Code if errCode := (*lsproto.ErrorCode)(nil); errors.As(err, &errCode) { code = errCode.Code } // TODO(jakebailey): error data - return s.sendResponse(&lsproto.ResponseMessage{ + s.sendResponse(&lsproto.ResponseMessage{ ID: id, Error: &lsproto.ResponseError{ Code: code, @@ -258,63 +344,56 @@ func (s *Server) sendError(id *lsproto.ID, err error) error { }) } -func (s *Server) sendResponse(resp *lsproto.ResponseMessage) error { - if !s.requestTime.IsZero() { - s.logger.PerfTrace(fmt.Sprintf("%s: %s", s.requestMethod, time.Since(s.requestTime))) - } - data, err := json.Marshal(resp) - if err != nil { - return err - } - return s.w.Write(data) +func (s *Server) sendResponse(resp *lsproto.ResponseMessage) { + s.outgoingQueue <- resp.Message() } -func (s *Server) handleMessage(req *lsproto.RequestMessage) error { - s.requestTime = time.Now() - s.requestMethod = string(req.Method) - +func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params switch params.(type) { case *lsproto.InitializeParams: - return s.sendError(req.ID, lsproto.ErrInvalidRequest) + s.sendError(req.ID, lsproto.ErrInvalidRequest) + return nil case *lsproto.InitializedParams: - return s.handleInitialized(req) + s.handleInitialized(ctx, req) case *lsproto.DidOpenTextDocumentParams: - return s.handleDidOpen(req) + s.handleDidOpen(ctx, req) case *lsproto.DidChangeTextDocumentParams: - return s.handleDidChange(req) + s.handleDidChange(ctx, req) case *lsproto.DidSaveTextDocumentParams: - return s.handleDidSave(req) + s.handleDidSave(ctx, req) case *lsproto.DidCloseTextDocumentParams: - return s.handleDidClose(req) + s.handleDidClose(ctx, req) case *lsproto.DidChangeWatchedFilesParams: - return s.handleDidChangeWatchedFiles(req) + s.handleDidChangeWatchedFiles(ctx, req) case *lsproto.DocumentDiagnosticParams: - return s.handleDocumentDiagnostic(req) + s.handleDocumentDiagnostic(ctx, req) case *lsproto.HoverParams: - return s.handleHover(req) + s.handleHover(ctx, req) case *lsproto.DefinitionParams: - return s.handleDefinition(req) + s.handleDefinition(ctx, req) case *lsproto.CompletionParams: - return s.handleCompletion(req) + s.handleCompletion(ctx, req) default: switch req.Method { case lsproto.MethodShutdown: s.projectService.Close() - return s.sendResult(req.ID, nil) + s.sendResult(req.ID, nil) + return nil case lsproto.MethodExit: return nil default: s.Log("unknown method", req.Method) if req.ID != nil { - return s.sendError(req.ID, lsproto.ErrInvalidRequest) + s.sendError(req.ID, lsproto.ErrInvalidRequest) } return nil } } + return nil } -func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { +func (s *Server) handleInitialize(req *lsproto.RequestMessage) { s.initializeParams = req.Params.(*lsproto.InitializeParams) s.positionEncoding = lsproto.PositionEncodingKindUTF16 @@ -324,7 +403,7 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { } } - return s.sendResult(req.ID, &lsproto.InitializeResult{ + s.sendResult(req.ID, &lsproto.InitializeResult{ ServerInfo: &lsproto.ServerInfo{ Name: "typescript-go", Version: ptrTo(core.Version), @@ -361,7 +440,7 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { }) } -func (s *Server) handleInitialized(req *lsproto.RequestMessage) error { +func (s *Server) handleInitialized(ctx context.Context, req *lsproto.RequestMessage) error { if s.initializeParams.Capabilities.Workspace.DidChangeWatchedFiles != nil && *s.initializeParams.Capabilities.Workspace.DidChangeWatchedFiles.DynamicRegistration { s.watchEnabled = true } @@ -380,24 +459,26 @@ func (s *Server) handleInitialized(req *lsproto.RequestMessage) error { return nil } -func (s *Server) handleDidOpen(req *lsproto.RequestMessage) error { +func (s *Server) handleDidOpen(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidOpenTextDocumentParams) s.projectService.OpenFile(ls.DocumentURIToFileName(params.TextDocument.Uri), params.TextDocument.Text, ls.LanguageKindToScriptKind(params.TextDocument.LanguageId), "") return nil } -func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { +func (s *Server) handleDidChange(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidChangeTextDocumentParams) scriptInfo := s.projectService.GetScriptInfo(ls.DocumentURIToFileName(params.TextDocument.Uri)) if scriptInfo == nil { - return s.sendError(req.ID, lsproto.ErrRequestFailed) + s.sendError(req.ID, lsproto.ErrRequestFailed) + return nil } changes := make([]ls.TextChange, len(params.ContentChanges)) for i, change := range params.ContentChanges { if partialChange := change.TextDocumentContentChangePartial; partialChange != nil { if textChange, err := s.converters.FromLSPTextChange(partialChange, scriptInfo.FileName()); err != nil { - return s.sendError(req.ID, err) + s.sendError(req.ID, err) + return nil } else { changes[i] = textChange } @@ -407,7 +488,8 @@ func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { NewText: wholeChange.Text, } } else { - return s.sendError(req.ID, lsproto.ErrInvalidRequest) + s.sendError(req.ID, lsproto.ErrInvalidRequest) + return nil } } @@ -415,36 +497,37 @@ func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { return nil } -func (s *Server) handleDidSave(req *lsproto.RequestMessage) error { +func (s *Server) handleDidSave(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidSaveTextDocumentParams) s.projectService.MarkFileSaved(ls.DocumentURIToFileName(params.TextDocument.Uri), *params.Text) return nil } -func (s *Server) handleDidClose(req *lsproto.RequestMessage) error { +func (s *Server) handleDidClose(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidCloseTextDocumentParams) s.projectService.CloseFile(ls.DocumentURIToFileName(params.TextDocument.Uri)) return nil } -func (s *Server) handleDidChangeWatchedFiles(req *lsproto.RequestMessage) error { +func (s *Server) handleDidChangeWatchedFiles(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidChangeWatchedFilesParams) return s.projectService.OnWatchedFilesChanged(params.Changes) } -func (s *Server) handleDocumentDiagnostic(req *lsproto.RequestMessage) error { +func (s *Server) handleDocumentDiagnostic(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DocumentDiagnosticParams) file, project := s.getFileAndProject(params.TextDocument.Uri) diagnostics := project.LanguageService().GetDocumentDiagnostics(file.FileName()) lspDiagnostics := make([]*lsproto.Diagnostic, len(diagnostics)) for i, diag := range diagnostics { if lspDiagnostic, err := s.converters.ToLSPDiagnostic(diag); err != nil { - return s.sendError(req.ID, err) + s.sendError(req.ID, err) + return nil } else { lspDiagnostics[i] = lspDiagnostic } } - return s.sendResult(req.ID, &lsproto.DocumentDiagnosticReport{ + s.sendResult(req.ID, &lsproto.DocumentDiagnosticReport{ RelatedFullDocumentDiagnosticReport: &lsproto.RelatedFullDocumentDiagnosticReport{ FullDocumentDiagnosticReport: lsproto.FullDocumentDiagnosticReport{ Kind: lsproto.StringLiteralFull{}, @@ -452,18 +535,20 @@ func (s *Server) handleDocumentDiagnostic(req *lsproto.RequestMessage) error { }, }, }) + return nil } -func (s *Server) handleHover(req *lsproto.RequestMessage) error { +func (s *Server) handleHover(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.HoverParams) file, project := s.getFileAndProject(params.TextDocument.Uri) pos, err := s.converters.LineAndCharacterToPositionForFile(params.Position, file.FileName()) if err != nil { - return s.sendError(req.ID, err) + s.sendError(req.ID, err) + return nil } hoverText := project.LanguageService().ProvideHover(file.FileName(), pos) - return s.sendResult(req.ID, &lsproto.Hover{ + s.sendResult(req.ID, &lsproto.Hover{ Contents: lsproto.MarkupContentOrMarkedStringOrMarkedStrings{ MarkupContent: &lsproto.MarkupContent{ Kind: lsproto.MarkupKindMarkdown, @@ -471,35 +556,41 @@ func (s *Server) handleHover(req *lsproto.RequestMessage) error { }, }, }) + + return nil } -func (s *Server) handleDefinition(req *lsproto.RequestMessage) error { +func (s *Server) handleDefinition(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DefinitionParams) file, project := s.getFileAndProject(params.TextDocument.Uri) pos, err := s.converters.LineAndCharacterToPositionForFile(params.Position, file.FileName()) if err != nil { - return s.sendError(req.ID, err) + s.sendError(req.ID, err) + return nil } locations := project.LanguageService().ProvideDefinitions(file.FileName(), pos) lspLocations := make([]lsproto.Location, len(locations)) for i, loc := range locations { if lspLocation, err := s.converters.ToLSPLocation(loc); err != nil { - return s.sendError(req.ID, err) + s.sendError(req.ID, err) + return nil } else { lspLocations[i] = lspLocation } } - return s.sendResult(req.ID, &lsproto.Definition{Locations: &lspLocations}) + s.sendResult(req.ID, &lsproto.Definition{Locations: &lspLocations}) + return nil } -func (s *Server) handleCompletion(req *lsproto.RequestMessage) (messageErr error) { +func (s *Server) handleCompletion(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.CompletionParams) file, project := s.getFileAndProject(params.TextDocument.Uri) pos, err := s.converters.LineAndCharacterToPositionForFile(params.Position, file.FileName()) if err != nil { - return s.sendError(req.ID, err) + s.sendError(req.ID, err) + return nil } // !!! remove this after completions is fully ported/tested @@ -507,7 +598,7 @@ func (s *Server) handleCompletion(req *lsproto.RequestMessage) (messageErr error if r := recover(); r != nil { stack := debug.Stack() s.Log("panic obtaining completions:", r, string(stack)) - messageErr = s.sendResult(req.ID, &lsproto.CompletionList{}) + s.sendResult(req.ID, &lsproto.CompletionList{}) } }() // !!! get user preferences @@ -517,7 +608,8 @@ func (s *Server) handleCompletion(req *lsproto.RequestMessage) (messageErr error params.Context, s.initializeParams.Capabilities.TextDocument.Completion, &ls.UserPreferences{}) - return s.sendResult(req.ID, list) + s.sendResult(req.ID, list) + return nil } func (s *Server) getFileAndProject(uri lsproto.DocumentUri) (*project.ScriptInfo, *project.Project) { From 5ed46bf99231b569a743370198b1e320bdad91d3 Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Tue, 29 Apr 2025 12:48:48 -0700 Subject: [PATCH 02/12] Shut down correctly --- internal/lsp/server.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 71e44545fe..6c5c21d7c5 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -195,18 +195,25 @@ func (s *Server) RefreshDiagnostics() error { func (s *Server) Run() error { go s.dispatchLoop() go s.writeLoop() - return s.readLoop() + go s.readLoop() + err := <-s.fatalErrChan + return err } -func (s *Server) readLoop() error { +func (s *Server) readLoop() { for { msg, err := s.read() if err != nil { + if errors.Is(err, io.EOF) { + s.fatalErrChan <- nil + return + } if errors.Is(err, lsproto.ErrInvalidRequest) { s.sendError(nil, err) continue } - return err + s.fatalErrChan <- err + return } if s.initializeParams == nil && msg.Kind == lsproto.MessageKindRequest { @@ -381,6 +388,7 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R s.sendResult(req.ID, nil) return nil case lsproto.MethodExit: + s.fatalErrChan <- nil return nil default: s.Log("unknown method", req.Method) From 43ac686012fd9299f728a736ba8dde1b313582d7 Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Wed, 30 Apr 2025 15:06:12 -0700 Subject: [PATCH 03/12] Parallelize requests and give each access to its own checker via pool --- internal/checker/checker_test.go | 3 +- internal/compiler/checkerpool.go | 90 ++++++++ internal/compiler/emitHost.go | 5 +- internal/compiler/program.go | 85 ++++--- internal/core/context.go | 21 ++ internal/execute/tsc.go | 11 +- internal/ls/api.go | 9 +- internal/ls/completions.go | 26 ++- internal/ls/completions_test.go | 7 +- internal/ls/converters.go | 141 +++--------- internal/ls/definition.go | 25 ++- internal/ls/diagnostics.go | 67 +++++- internal/ls/host.go | 20 +- internal/ls/hover.go | 45 +++- internal/ls/languageservice.go | 79 +++---- internal/ls/utilities.go | 5 +- internal/lsp/lsproto/jsonrpc.go | 8 + internal/lsp/server.go | 212 ++++++------------ internal/project/checkerpool.go | 163 ++++++++++++++ internal/project/project.go | 69 ++++-- internal/project/scriptinfo.go | 2 +- internal/project/service.go | 36 ++- internal/project/service_test.go | 79 ++++++- internal/testrunner/compiler_runner.go | 4 +- internal/testutil/harnessutil/harnessutil.go | 7 +- internal/testutil/lstestutil/lstestutil.go | 14 +- .../tsbaseline/type_symbol_baseline.go | 8 +- 27 files changed, 796 insertions(+), 445 deletions(-) create mode 100644 internal/compiler/checkerpool.go create mode 100644 internal/core/context.go create mode 100644 internal/project/checkerpool.go diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index 3922e6ffd3..3458c8687c 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -39,7 +39,8 @@ foo.bar;` } p := compiler.NewProgram(opts) p.BindSourceFiles() - c := p.GetTypeChecker() + c, done := p.GetTypeChecker(t.Context()) + defer done() file := p.GetSourceFile("/foo.ts") interfaceId := file.Statements.Nodes[0].Name() varId := file.Statements.Nodes[1].AsVariableStatement().DeclarationList.AsVariableDeclarationList().Declarations.Nodes[0].Name() diff --git a/internal/compiler/checkerpool.go b/internal/compiler/checkerpool.go new file mode 100644 index 0000000000..17949cd568 --- /dev/null +++ b/internal/compiler/checkerpool.go @@ -0,0 +1,90 @@ +package compiler + +import ( + "context" + "iter" + "slices" + "sync" + + "github.com/microsoft/typescript-go/internal/ast" + "github.com/microsoft/typescript-go/internal/checker" + "github.com/microsoft/typescript-go/internal/core" +) + +type CheckerPool interface { + GetChecker(ctx context.Context) (*checker.Checker, func()) + GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) + GetAllCheckers(ctx context.Context) ([]*checker.Checker, func()) + Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] +} + +type checkerPool struct { + checkerCount int + program *Program + + createCheckersOnce sync.Once + checkers []*checker.Checker + fileAssociations map[*ast.SourceFile]*checker.Checker +} + +var _ CheckerPool = (*checkerPool)(nil) + +func newCheckerPool(checkerCount int, program *Program) *checkerPool { + pool := &checkerPool{ + program: program, + checkerCount: checkerCount, + checkers: make([]*checker.Checker, checkerCount), + } + + return pool +} + +func (p *checkerPool) GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) { + p.createCheckers() + checker := p.fileAssociations[file] + return checker, noop +} + +func (p *checkerPool) GetChecker(ctx context.Context) (*checker.Checker, func()) { + p.createCheckers() + checker := p.checkers[0] + return checker, noop +} + +func (p *checkerPool) createCheckers() { + p.createCheckersOnce.Do(func() { + wg := core.NewWorkGroup(p.program.singleThreaded()) + for i := range p.checkerCount { + wg.Queue(func() { + p.checkers[i] = checker.NewChecker(p.program) + }) + } + + wg.RunAndWait() + + p.fileAssociations = make(map[*ast.SourceFile]*checker.Checker, len(p.program.files)) + for i, file := range p.program.files { + p.fileAssociations[file] = p.checkers[i%p.checkerCount] + } + }) +} + +func (p *checkerPool) GetAllCheckers(ctx context.Context) ([]*checker.Checker, func()) { + p.createCheckers() + return p.checkers, noop +} + +func (p *checkerPool) Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] { + checkerIndex := slices.Index(p.checkers, checker) + return func(yield func(*ast.SourceFile) bool) { + for i, file := range p.program.files { + if i%p.checkerCount == checkerIndex { + if !yield(file) { + return + } + } + } + } +} + +func noop() {} diff --git a/internal/compiler/emitHost.go b/internal/compiler/emitHost.go index 0f73ed940f..3cd269a11d 100644 --- a/internal/compiler/emitHost.go +++ b/internal/compiler/emitHost.go @@ -1,6 +1,8 @@ package compiler import ( + "context" + "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/printer" @@ -32,7 +34,8 @@ func (host *emitHost) WriteFile(fileName string, text string, writeByteOrderMark } func (host *emitHost) GetEmitResolver(file *ast.SourceFile, skipDiagnostics bool) printer.EmitResolver { - checker := host.program.GetTypeCheckerForFile(file) + checker, done := host.program.GetTypeCheckerForFile(context.Background(), file) + defer done() return checker.GetEmitResolver(file, skipDiagnostics) } diff --git a/internal/compiler/program.go b/internal/compiler/program.go index cb1a6c5e3e..18540f24be 100644 --- a/internal/compiler/program.go +++ b/internal/compiler/program.go @@ -27,6 +27,7 @@ type ProgramOptions struct { SingleThreaded core.Tristate ProjectReference []core.ProjectReference ConfigFileParsingDiagnostics []*ast.Diagnostic + CheckerPool CheckerPool } type Program struct { @@ -35,9 +36,7 @@ type Program struct { compilerOptions *core.CompilerOptions configFileName string nodeModules map[string]*ast.SourceFile - checkers []*checker.Checker - checkersOnce sync.Once - checkersByFile map[*ast.SourceFile]*checker.Checker + checkerPool CheckerPool currentDirectory string configFileParsingDiagnostics []*ast.Diagnostic @@ -79,6 +78,10 @@ func NewProgram(options ProgramOptions) *Program { if p.compilerOptions == nil { p.compilerOptions = &core.CompilerOptions{} } + p.checkerPool = options.CheckerPool + if p.checkerPool == nil { + p.checkerPool = newCheckerPool(core.IfElse(p.singleThreaded(), 1, 4), p) + } // p.maxNodeModuleJsDepth = p.options.MaxNodeModuleJsDepth @@ -212,11 +215,12 @@ func (p *Program) BindSourceFiles() { } func (p *Program) CheckSourceFiles(ctx context.Context) { - p.createCheckers() wg := core.NewWorkGroup(p.singleThreaded()) - for index, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(ctx) + defer done() + for index, checker := range checkers { wg.Queue(func() { - for i := index; i < len(p.files); i += len(p.checkers) { + for i := index; i < len(p.files); i += len(checkers) { checker.CheckSourceFile(ctx, p.files[i]) } }) @@ -224,44 +228,21 @@ func (p *Program) CheckSourceFiles(ctx context.Context) { wg.RunAndWait() } -func (p *Program) createCheckers() { - p.checkersOnce.Do(func() { - p.checkers = make([]*checker.Checker, core.IfElse(p.singleThreaded(), 1, 4)) - wg := core.NewWorkGroup(p.singleThreaded()) - for i := range p.checkers { - wg.Queue(func() { - p.checkers[i] = checker.NewChecker(p) - }) - } - wg.RunAndWait() - p.checkersByFile = make(map[*ast.SourceFile]*checker.Checker) - for i, file := range p.files { - p.checkersByFile[file] = p.checkers[i%len(p.checkers)] - } - }) -} - // Return the type checker associated with the program. -func (p *Program) GetTypeChecker() *checker.Checker { - p.createCheckers() - // Just use the first (and possibly only) checker for checker requests. Such requests are likely - // to obtain types through multiple API calls and we want to ensure that those types are created - // by the same checker so they can interoperate. - return p.checkers[0] +func (p *Program) GetTypeChecker(ctx context.Context) (*checker.Checker, func()) { + return p.checkerPool.GetChecker(ctx) } -func (p *Program) GetTypeCheckers() []*checker.Checker { - p.createCheckers() - return p.checkers +func (p *Program) GetTypeCheckers(ctx context.Context) ([]*checker.Checker, func()) { + return p.checkerPool.GetAllCheckers(ctx) } // Return a checker for the given file. We may have multiple checkers in concurrent scenarios and this // method returns the checker that was tasked with checking the file. Note that it isn't possible to mix // types obtained from different checkers, so only non-type data (such as diagnostics or string // representations of types) should be obtained from checkers returned by this method. -func (p *Program) GetTypeCheckerForFile(file *ast.SourceFile) *checker.Checker { - p.createCheckers() - return p.checkersByFile[file] +func (p *Program) GetTypeCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) { + return p.checkerPool.GetCheckerForFile(ctx, file) } func (p *Program) GetResolvedModule(file *ast.SourceFile, moduleReference string) *ast.SourceFile { @@ -294,17 +275,19 @@ func (p *Program) GetSemanticDiagnostics(ctx context.Context, sourceFile *ast.So return p.getDiagnosticsHelper(ctx, sourceFile, true /*ensureBound*/, true /*ensureChecked*/, p.getSemanticDiagnosticsForFile) } -func (p *Program) GetGlobalDiagnostics() []*ast.Diagnostic { - p.createCheckers() +func (p *Program) GetGlobalDiagnostics(ctx context.Context) []*ast.Diagnostic { var globalDiagnostics []*ast.Diagnostic - for _, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(ctx) + defer done() + for _, checker := range checkers { globalDiagnostics = append(globalDiagnostics, checker.GetGlobalDiagnostics()...) } + return SortAndDeduplicateDiagnostics(globalDiagnostics) } -func (p *Program) GetOptionsDiagnostics() []*ast.Diagnostic { - return SortAndDeduplicateDiagnostics(append(p.GetGlobalDiagnostics(), p.getOptionsDiagnosticsOfConfigFile()...)) +func (p *Program) GetOptionsDiagnostics(ctx context.Context) []*ast.Diagnostic { + return SortAndDeduplicateDiagnostics(append(p.GetGlobalDiagnostics(ctx), p.getOptionsDiagnosticsOfConfigFile()...)) } func (p *Program) getOptionsDiagnosticsOfConfigFile() []*ast.Diagnostic { @@ -332,14 +315,20 @@ func (p *Program) getSemanticDiagnosticsForFile(ctx context.Context, sourceFile if checker.SkipTypeChecking(sourceFile, p.compilerOptions) { return nil } + var fileChecker *checker.Checker + var done func() if sourceFile != nil { - fileChecker = p.GetTypeCheckerForFile(sourceFile) + fileChecker, done = p.checkerPool.GetCheckerForFile(ctx, sourceFile) + defer done() } diags := slices.Clip(sourceFile.BindDiagnostics()) + checkers, closeCheckers := p.checkerPool.GetAllCheckers(ctx) + defer closeCheckers() + // Ask for diags from all checkers; checking one file may add diagnostics to other files. // These are deduplicated later. - for _, checker := range p.checkers { + for _, checker := range checkers { if sourceFile == nil || checker == fileChecker { diags = append(diags, checker.GetDiagnostics(ctx, sourceFile)...) } else { @@ -482,7 +471,9 @@ func (p *Program) SymbolCount() int { for _, file := range p.files { count += file.SymbolCount } - for _, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(context.Background()) + defer done() + for _, checker := range checkers { count += int(checker.SymbolCount) } return count @@ -490,7 +481,9 @@ func (p *Program) SymbolCount() int { func (p *Program) TypeCount() int { var count int - for _, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(context.Background()) + defer done() + for _, checker := range checkers { count += int(checker.TypeCount) } return count @@ -498,7 +491,9 @@ func (p *Program) TypeCount() int { func (p *Program) InstantiationCount() int { var count int - for _, checker := range p.checkers { + checkers, done := p.checkerPool.GetAllCheckers(context.Background()) + defer done() + for _, checker := range checkers { count += int(checker.TotalInstantiationCount) } return count diff --git a/internal/core/context.go b/internal/core/context.go new file mode 100644 index 0000000000..f755cdedba --- /dev/null +++ b/internal/core/context.go @@ -0,0 +1,21 @@ +package core + +import "context" + +type key int + +var requestIDKey key + +func WithRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, requestIDKey, id) +} + +func GetRequestID(ctx context.Context) string { + if ctx == nil { + return "" + } + if id, ok := ctx.Value(requestIDKey).(string); ok { + return id + } + return "" +} diff --git a/internal/execute/tsc.go b/internal/execute/tsc.go index 959817f882..e26528f36c 100644 --- a/internal/execute/tsc.go +++ b/internal/execute/tsc.go @@ -239,25 +239,26 @@ type compileAndEmitResult struct { func compileAndEmit(sys System, program *compiler.Program, reportDiagnostic diagnosticReporter) (result compileAndEmitResult) { // todo: check if third return needed after execute is fully implemented + ctx := context.Background() options := program.Options() allDiagnostics := program.GetConfigFileParsingDiagnostics() // todo: early exit logic and append diagnostics - diagnostics := program.GetSyntacticDiagnostics(context.Background(), nil) + diagnostics := program.GetSyntacticDiagnostics(ctx, nil) if len(diagnostics) == 0 { bindStart := time.Now() - _ = program.GetBindDiagnostics(context.Background(), nil) + _ = program.GetBindDiagnostics(ctx, nil) result.bindTime = time.Since(bindStart) - diagnostics = append(diagnostics, program.GetOptionsDiagnostics()...) + diagnostics = append(diagnostics, program.GetOptionsDiagnostics(ctx)...) if options.ListFilesOnly.IsFalse() { // program.GetBindDiagnostics(nil) - diagnostics = append(diagnostics, program.GetGlobalDiagnostics()...) + diagnostics = append(diagnostics, program.GetGlobalDiagnostics(ctx)...) } } if len(diagnostics) == 0 { checkStart := time.Now() - diagnostics = append(diagnostics, program.GetSemanticDiagnostics(context.Background(), nil)...) + diagnostics = append(diagnostics, program.GetSemanticDiagnostics(ctx, nil)...) result.checkTime = time.Since(checkStart) } // TODO: declaration diagnostics diff --git a/internal/ls/api.go b/internal/ls/api.go index 62fd233a6f..cb0827459d 100644 --- a/internal/ls/api.go +++ b/internal/ls/api.go @@ -15,7 +15,7 @@ var ( ) func (l *LanguageService) GetSymbolAtPosition(fileName string, position int) (*ast.Symbol, error) { - program, file := l.tryGetProgramAndFile(fileName) + _, file := l.tryGetProgramAndFile(fileName) if file == nil { return nil, fmt.Errorf("%w: %s", ErrNoSourceFile, fileName) } @@ -23,17 +23,16 @@ func (l *LanguageService) GetSymbolAtPosition(fileName string, position int) (*a if node == nil { return nil, fmt.Errorf("%w: %s:%d", ErrNoTokenAtPosition, fileName, position) } - checker := program.GetTypeChecker() + checker := l.GetTypeChecker(file) return checker.GetSymbolAtLocation(node), nil } func (l *LanguageService) GetSymbolAtLocation(node *ast.Node) *ast.Symbol { - program := l.GetProgram() - checker := program.GetTypeChecker() + checker := l.GetTypeChecker(ast.GetSourceFileOfNode(node)) return checker.GetSymbolAtLocation(node) } func (l *LanguageService) GetTypeOfSymbol(symbol *ast.Symbol) *checker.Type { - checker := l.GetProgram().GetTypeChecker() + checker := l.GetTypeChecker(nil /*file*/) return checker.GetTypeOfSymbolAtLocation(symbol, nil) } diff --git a/internal/ls/completions.go b/internal/ls/completions.go index 262688bda4..ec4abb82ca 100644 --- a/internal/ls/completions.go +++ b/internal/ls/completions.go @@ -22,14 +22,21 @@ import ( ) func (l *LanguageService) ProvideCompletion( - fileName string, - position int, + documentURI lsproto.DocumentUri, + position lsproto.Position, context *lsproto.CompletionContext, clientOptions *lsproto.CompletionClientCapabilities, preferences *UserPreferences, -) *lsproto.CompletionList { - program, file := l.getProgramAndFile(fileName) - return l.getCompletionsAtPosition(program, file, position, context, preferences, clientOptions) +) (*lsproto.CompletionList, error) { + program, file := l.getProgramAndFile(documentURI) + return l.getCompletionsAtPosition( + program, + file, + int(l.converters.LineAndCharacterToPosition(file, position)), + context, + preferences, + clientOptions, + ), nil } // *completionDataData | *completionDataKeyword @@ -287,7 +294,7 @@ func (l *LanguageService) getCompletionsAtPosition( // !!! label completions - data := getCompletionData(program, file, position, preferences) + data := getCompletionData(program, l.GetTypeChecker(file), file, position, preferences) if data == nil { return nil } @@ -313,8 +320,7 @@ func (l *LanguageService) getCompletionsAtPosition( } } -func getCompletionData(program *compiler.Program, file *ast.SourceFile, position int, preferences *UserPreferences) completionData { - typeChecker := program.GetTypeChecker() +func getCompletionData(program *compiler.Program, typeChecker *checker.Checker, file *ast.SourceFile, position int, preferences *UserPreferences) completionData { inCheckedFile := isCheckedFile(file, program.GetCompilerOptions()) currentToken := astnav.GetTokenAtPosition(file, position) @@ -1560,7 +1566,7 @@ func (l *LanguageService) getCompletionEntriesFromSymbols( ) (uniqueNames core.Set[string], sortedEntries []*lsproto.CompletionItem) { closestSymbolDeclaration := getClosestSymbolDeclaration(data.contextToken, data.location) useSemicolons := probablyUsesSemicolons(file) - typeChecker := program.GetTypeChecker() + typeChecker := l.GetTypeChecker(file) isMemberCompletion := isMemberCompletionKind(data.completionKind) optionalReplacementSpan := getOptionalReplacementSpan(data.location, file) // Tracks unique names. @@ -1694,7 +1700,7 @@ func (l *LanguageService) createCompletionItem( source := getSourceFromOrigin(origin) var labelDetails *lsproto.CompletionItemLabelDetails - typeChecker := program.GetTypeChecker() + typeChecker := l.GetTypeChecker(file) insertQuestionDot := originIsNullableMember(origin) useBraces := originIsSymbolMember(origin) || needsConvertPropertyAccess if originIsThisType(origin) { diff --git a/internal/ls/completions_test.go b/internal/ls/completions_test.go index 5b78a4a700..ad48c2ef3d 100644 --- a/internal/ls/completions_test.go +++ b/internal/ls/completions_test.go @@ -1575,12 +1575,13 @@ func runTest(t *testing.T, files map[string]string, expected map[string]*testCas if !ok { t.Fatalf("No marker found for '%s'", markerName) } - completionList := languageService.ProvideCompletion( - mainFileName, - marker.Position, + completionList, err := languageService.ProvideCompletion( + ls.FileNameToDocumentURI(mainFileName), + marker.LSPosition, context, capabilities, preferences) + assert.NilError(t, err) if expectedResult.isIncludes { assertIncludesItem(t, completionList, expectedResult.list) } else { diff --git a/internal/ls/converters.go b/internal/ls/converters.go index 204832b263..690870df03 100644 --- a/internal/ls/converters.go +++ b/internal/ls/converters.go @@ -8,137 +8,60 @@ import ( "unicode/utf16" "unicode/utf8" - "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/core" - "github.com/microsoft/typescript-go/internal/diagnostics" "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -type ScriptInfo interface { - Text() string - LineMap() *LineMap -} - type Converters struct { - getScriptInfo func(fileName string) ScriptInfo + getLineMap func(fileName string) *LineMap positionEncoding lsproto.PositionEncodingKind } -func NewConverters(positionEncoding lsproto.PositionEncodingKind, getScriptInfo func(fileName string) ScriptInfo) *Converters { +type Script interface { + FileName() string + Text() string +} + +func NewConverters(positionEncoding lsproto.PositionEncodingKind, getLineMap func(fileName string) *LineMap) *Converters { return &Converters{ - getScriptInfo: getScriptInfo, + getLineMap: getLineMap, positionEncoding: positionEncoding, } } -func (c *Converters) ToLSPRange(fileName string, textRange core.TextRange) (lsproto.Range, error) { - scriptInfo := c.getScriptInfo(fileName) - if scriptInfo == nil { - return lsproto.Range{}, fmt.Errorf("no script info found for %s", fileName) - } - +func (c *Converters) ToLSPRange(script Script, textRange core.TextRange) lsproto.Range { return lsproto.Range{ - Start: c.PositionToLineAndCharacter(scriptInfo, core.TextPos(textRange.Pos())), - End: c.PositionToLineAndCharacter(scriptInfo, core.TextPos(textRange.End())), - }, nil + Start: c.PositionToLineAndCharacter(script, core.TextPos(textRange.Pos())), + End: c.PositionToLineAndCharacter(script, core.TextPos(textRange.End())), + } } -func (c *Converters) FromLSPRange(textRange lsproto.Range, fileName string) (core.TextRange, error) { - scriptInfo := c.getScriptInfo(fileName) - if scriptInfo == nil { - return core.TextRange{}, fmt.Errorf("no script info found for %s", fileName) - } +func (c *Converters) FromLSPRange(script Script, textRange lsproto.Range) core.TextRange { return core.NewTextRange( - int(c.LineAndCharacterToPosition(scriptInfo, textRange.Start)), - int(c.LineAndCharacterToPosition(scriptInfo, textRange.End)), - ), nil + int(c.LineAndCharacterToPosition(script, textRange.Start)), + int(c.LineAndCharacterToPosition(script, textRange.End)), + ) } -func (c *Converters) FromLSPTextChange(change *lsproto.TextDocumentContentChangePartial, fileName string) (TextChange, error) { - textRange, err := c.FromLSPRange(change.Range, fileName) - if err != nil { - return TextChange{}, fmt.Errorf("error converting range: %w", err) - } +func (c *Converters) FromLSPTextChange(script Script, change *lsproto.TextDocumentContentChangePartial) TextChange { return TextChange{ - TextRange: textRange, + TextRange: c.FromLSPRange(script, change.Range), NewText: change.Text, - }, nil -} - -func (c *Converters) ToLSPLocation(location Location) (lsproto.Location, error) { - rng, err := c.ToLSPRange(location.FileName, location.Range) - if err != nil { - return lsproto.Location{}, err } - return lsproto.Location{ - Uri: FileNameToDocumentURI(location.FileName), - Range: rng, - }, nil } -func (c *Converters) FromLSPLocation(location lsproto.Location) (Location, error) { - fileName := DocumentURIToFileName(location.Uri) - rng, err := c.FromLSPRange(location.Range, fileName) - if err != nil { - return Location{}, err - } - return Location{ - FileName: fileName, - Range: rng, - }, nil -} - -func (c *Converters) ToLSPDiagnostic(diagnostic *ast.Diagnostic) (*lsproto.Diagnostic, error) { - textRange, err := c.ToLSPRange(diagnostic.File().FileName(), diagnostic.Loc()) - if err != nil { - return nil, fmt.Errorf("error converting diagnostic range: %w", err) - } - - var severity lsproto.DiagnosticSeverity - switch diagnostic.Category() { - case diagnostics.CategorySuggestion: - severity = lsproto.DiagnosticSeverityHint - case diagnostics.CategoryMessage: - severity = lsproto.DiagnosticSeverityInformation - case diagnostics.CategoryWarning: - severity = lsproto.DiagnosticSeverityWarning - default: - severity = lsproto.DiagnosticSeverityError - } - - relatedInformation := make([]*lsproto.DiagnosticRelatedInformation, 0, len(diagnostic.RelatedInformation())) - for _, related := range diagnostic.RelatedInformation() { - relatedRange, err := c.ToLSPRange(related.File().FileName(), related.Loc()) - if err != nil { - return nil, fmt.Errorf("error converting related info range: %w", err) - } - relatedInformation = append(relatedInformation, &lsproto.DiagnosticRelatedInformation{ - Location: lsproto.Location{ - Uri: FileNameToDocumentURI(related.File().FileName()), - Range: relatedRange, - }, - Message: related.Message(), - }) +func (c *Converters) ToLSPLocation(script Script, rng core.TextRange) lsproto.Location { + return lsproto.Location{ + Uri: FileNameToDocumentURI(script.FileName()), + Range: c.ToLSPRange(script, rng), } - - return &lsproto.Diagnostic{ - Range: textRange, - Code: &lsproto.IntegerOrString{ - Integer: ptrTo(diagnostic.Code()), - }, - Severity: &severity, - Message: diagnostic.Message(), - Source: ptrTo("ts"), - RelatedInformation: &relatedInformation, - }, nil } -func (c *Converters) LineAndCharacterToPositionForFile(lineAndCharacter lsproto.Position, fileName string) (int, error) { - scriptInfo := c.getScriptInfo(fileName) - if scriptInfo == nil { - return 0, fmt.Errorf("no script info found for %s", fileName) +func (c *Converters) FromLSPLocation(script Script, rng lsproto.Range) Location { + return Location{ + FileName: script.FileName(), + Range: c.FromLSPRange(script, rng), } - return int(c.LineAndCharacterToPosition(scriptInfo, lineAndCharacter)), nil } func LanguageKindToScriptKind(languageID lsproto.LanguageKind) core.ScriptKind { @@ -202,10 +125,10 @@ func FileNameToDocumentURI(fileName string) lsproto.DocumentUri { return lsproto.DocumentUri("file://" + fileName) } -func (c *Converters) LineAndCharacterToPosition(scriptInfo ScriptInfo, lineAndCharacter lsproto.Position) core.TextPos { +func (c *Converters) LineAndCharacterToPosition(script Script, lineAndCharacter lsproto.Position) core.TextPos { // UTF-8/16 0-indexed line and character to UTF-8 offset - lineMap := scriptInfo.LineMap() + lineMap := c.getLineMap(script.FileName()) line := core.TextPos(lineAndCharacter.Line) char := core.TextPos(lineAndCharacter.Character) @@ -222,7 +145,7 @@ func (c *Converters) LineAndCharacterToPosition(scriptInfo ScriptInfo, lineAndCh var utf8Char core.TextPos var utf16Char core.TextPos - for i, r := range scriptInfo.Text()[start:] { + for i, r := range script.Text()[start:] { u16Len := core.TextPos(utf16.RuneLen(r)) if utf16Char+u16Len > char { break @@ -234,10 +157,10 @@ func (c *Converters) LineAndCharacterToPosition(scriptInfo ScriptInfo, lineAndCh return start + utf8Char } -func (c *Converters) PositionToLineAndCharacter(scriptInfo ScriptInfo, position core.TextPos) lsproto.Position { +func (c *Converters) PositionToLineAndCharacter(script Script, position core.TextPos) lsproto.Position { // UTF-8 offset to UTF-8/16 0-indexed line and character - lineMap := scriptInfo.LineMap() + lineMap := c.getLineMap(script.FileName()) line, isLineStart := slices.BinarySearch(lineMap.LineStarts, position) if !isLineStart { @@ -254,7 +177,7 @@ func (c *Converters) PositionToLineAndCharacter(scriptInfo ScriptInfo, position character = position - start } else { // We need to rescan the text as UTF-16 to find the character offset. - for _, r := range scriptInfo.Text()[start:position] { + for _, r := range script.Text()[start:position] { character += core.TextPos(utf16.RuneLen(r)) } } diff --git a/internal/ls/definition.go b/internal/ls/definition.go index f4ed52124f..9eeed3072d 100644 --- a/internal/ls/definition.go +++ b/internal/ls/definition.go @@ -4,17 +4,19 @@ import ( "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/astnav" "github.com/microsoft/typescript-go/internal/core" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" "github.com/microsoft/typescript-go/internal/scanner" ) -func (l *LanguageService) ProvideDefinitions(fileName string, position int) []Location { - program, file := l.getProgramAndFile(fileName) - node := astnav.GetTouchingPropertyName(file, position) +func (l *LanguageService) ProvideDefinition(documentURI lsproto.DocumentUri, position lsproto.Position) (*lsproto.Definition, error) { + file := l.getSourceFile(documentURI) + node := astnav.GetTouchingPropertyName(file, int(l.converters.LineAndCharacterToPosition(file, position))) if node.Kind == ast.KindSourceFile { - return nil + return nil, nil } - checker := program.GetTypeChecker() + checker := l.GetTypeChecker(file) + if symbol := checker.GetSymbolAtLocation(node); symbol != nil { if symbol.Flags&ast.SymbolFlagsAlias != 0 { if resolved, ok := checker.ResolveAlias(symbol); ok { @@ -22,18 +24,17 @@ func (l *LanguageService) ProvideDefinitions(fileName string, position int) []Lo } } - locations := make([]Location, 0, len(symbol.Declarations)) + locations := make([]lsproto.Location, 0, len(symbol.Declarations)) for _, decl := range symbol.Declarations { file := ast.GetSourceFileOfNode(decl) loc := decl.Loc pos := scanner.GetTokenPosOfNode(decl, file, false /*includeJSDoc*/) - - locations = append(locations, Location{ - FileName: file.FileName(), - Range: core.NewTextRange(pos, loc.End()), + locations = append(locations, lsproto.Location{ + Uri: FileNameToDocumentURI(file.FileName()), + Range: l.converters.ToLSPRange(file, core.NewTextRange(pos, loc.End())), }) } - return locations + return &lsproto.Definition{Locations: &locations}, nil } - return nil + return nil, nil } diff --git a/internal/ls/diagnostics.go b/internal/ls/diagnostics.go index 1994ec740f..b16a92566c 100644 --- a/internal/ls/diagnostics.go +++ b/internal/ls/diagnostics.go @@ -2,14 +2,71 @@ package ls import ( "context" - "slices" "github.com/microsoft/typescript-go/internal/ast" + "github.com/microsoft/typescript-go/internal/diagnostics" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -func (l *LanguageService) GetDocumentDiagnostics(fileName string) []*ast.Diagnostic { - program, file := l.getProgramAndFile(fileName) +func (l *LanguageService) GetDocumentDiagnostics(documentURI lsproto.DocumentUri) (*lsproto.DocumentDiagnosticReport, error) { + program, file := l.getProgramAndFile(documentURI) syntaxDiagnostics := program.GetSyntacticDiagnostics(context.Background(), file) - semanticDiagnostics := program.GetSemanticDiagnostics(context.Background(), file) - return slices.Concat(syntaxDiagnostics, semanticDiagnostics) + var lspDiagnostics []*lsproto.Diagnostic + if len(syntaxDiagnostics) != 0 { + lspDiagnostics = make([]*lsproto.Diagnostic, len(syntaxDiagnostics)) + for i, diag := range syntaxDiagnostics { + lspDiagnostics[i] = toLSPDiagnostic(diag, l.converters) + } + } else { + checker := l.GetTypeChecker(file) + semanticDiagnostics := checker.GetDiagnostics(context.Background(), file) + lspDiagnostics = make([]*lsproto.Diagnostic, len(semanticDiagnostics)) + for i, diag := range semanticDiagnostics { + lspDiagnostics[i] = toLSPDiagnostic(diag, l.converters) + } + } + return &lsproto.DocumentDiagnosticReport{ + RelatedFullDocumentDiagnosticReport: &lsproto.RelatedFullDocumentDiagnosticReport{ + FullDocumentDiagnosticReport: lsproto.FullDocumentDiagnosticReport{ + Kind: lsproto.StringLiteralFull{}, + Items: lspDiagnostics, + }, + }, + }, nil +} + +func toLSPDiagnostic(diagnostic *ast.Diagnostic, converters *Converters) *lsproto.Diagnostic { + var severity lsproto.DiagnosticSeverity + switch diagnostic.Category() { + case diagnostics.CategorySuggestion: + severity = lsproto.DiagnosticSeverityHint + case diagnostics.CategoryMessage: + severity = lsproto.DiagnosticSeverityInformation + case diagnostics.CategoryWarning: + severity = lsproto.DiagnosticSeverityWarning + default: + severity = lsproto.DiagnosticSeverityError + } + + relatedInformation := make([]*lsproto.DiagnosticRelatedInformation, 0, len(diagnostic.RelatedInformation())) + for _, related := range diagnostic.RelatedInformation() { + relatedInformation = append(relatedInformation, &lsproto.DiagnosticRelatedInformation{ + Location: lsproto.Location{ + Uri: FileNameToDocumentURI(related.File().FileName()), + Range: converters.ToLSPRange(related.File(), related.Loc()), + }, + Message: related.Message(), + }) + } + + return &lsproto.Diagnostic{ + Range: converters.ToLSPRange(diagnostic.File(), diagnostic.Loc()), + Code: &lsproto.IntegerOrString{ + Integer: ptrTo(diagnostic.Code()), + }, + Severity: &severity, + Message: diagnostic.Message(), + Source: ptrTo("ts"), + RelatedInformation: &relatedInformation, + } } diff --git a/internal/ls/host.go b/internal/ls/host.go index d54cd1b3ad..493596fc45 100644 --- a/internal/ls/host.go +++ b/internal/ls/host.go @@ -1,30 +1,12 @@ package ls import ( - "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/compiler" - "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/lsp/lsproto" - "github.com/microsoft/typescript-go/internal/tspath" - "github.com/microsoft/typescript-go/internal/vfs" ) type Host interface { - FS() vfs.FS - DefaultLibraryPath() string - GetCurrentDirectory() string - NewLine() string - Trace(msg string) - GetProjectVersion() int - // GetRootFileNames was called GetScriptFileNames in the original code. - GetRootFileNames() []string - // GetCompilerOptions was called GetCompilationSettings in the original code. - GetCompilerOptions() *core.CompilerOptions - GetSourceFile(fileName string, path tspath.Path, languageVersion core.ScriptTarget) *ast.SourceFile - // This responsibility was moved from the language service to the project, - // because they were bidirectionally interdependent. GetProgram() *compiler.Program - GetDefaultLibraryPath() string GetPositionEncoding() lsproto.PositionEncodingKind - GetScriptInfo(fileName string) ScriptInfo + GetLineMap(fileName string) *LineMap } diff --git a/internal/ls/hover.go b/internal/ls/hover.go index fdf5273f40..fd20995ab1 100644 --- a/internal/ls/hover.go +++ b/internal/ls/hover.go @@ -1,16 +1,53 @@ package ls import ( + "strings" + "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/astnav" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -func (l *LanguageService) ProvideHover(fileName string, position int) string { - program, file := l.getProgramAndFile(fileName) - node := astnav.GetTouchingPropertyName(file, position) +func (l *LanguageService) ProvideHover(documentURI lsproto.DocumentUri, position lsproto.Position) (*lsproto.Hover, error) { + file := l.getSourceFile(documentURI) + node := astnav.GetTouchingPropertyName(file, int(l.converters.LineAndCharacterToPosition(file, position))) if node.Kind == ast.KindSourceFile { // Avoid giving quickInfo for the sourceFile as a whole. + return nil, nil + } + result := l.GetTypeChecker(file).GetQuickInfoAtLocation(node) + if result != "" { + return &lsproto.Hover{ + Contents: lsproto.MarkupContentOrMarkedStringOrMarkedStrings{ + MarkupContent: &lsproto.MarkupContent{ + Kind: lsproto.MarkupKindMarkdown, + Value: codeFence("typescript", result), + }, + }, + }, nil + } + return nil, nil +} + +func codeFence(lang string, code string) string { + if code == "" { return "" } - return program.GetTypeChecker().GetQuickInfoAtLocation(node) + ticks := 3 + for strings.Contains(code, strings.Repeat("`", ticks)) { + ticks++ + } + var result strings.Builder + result.Grow(len(code) + len(lang) + 2*ticks + 2) + for range ticks { + result.WriteByte('`') + } + result.WriteString(lang) + result.WriteByte('\n') + result.WriteString(code) + result.WriteByte('\n') + for range ticks { + result.WriteByte('`') + } + return result.String() } diff --git a/internal/ls/languageservice.go b/internal/ls/languageservice.go index 2fc19e4549..900fce3a14 100644 --- a/internal/ls/languageservice.go +++ b/internal/ls/languageservice.go @@ -1,60 +1,51 @@ package ls import ( + "context" + "github.com/microsoft/typescript-go/internal/ast" + "github.com/microsoft/typescript-go/internal/checker" "github.com/microsoft/typescript-go/internal/compiler" - "github.com/microsoft/typescript-go/internal/core" - "github.com/microsoft/typescript-go/internal/tspath" - "github.com/microsoft/typescript-go/internal/vfs" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -var _ compiler.CompilerHost = (*LanguageService)(nil) - type LanguageService struct { - converters *Converters - host Host + ctx context.Context + host Host + converters *Converters + disposables []func() } -func NewLanguageService(host Host) *LanguageService { +func NewLanguageService(ctx context.Context, host Host) *LanguageService { return &LanguageService{ + ctx: ctx, host: host, - converters: NewConverters(host.GetPositionEncoding(), host.GetScriptInfo), + converters: NewConverters(host.GetPositionEncoding(), host.GetLineMap), } } -// FS implements compiler.CompilerHost. -func (l *LanguageService) FS() vfs.FS { - return l.host.FS() -} - -// DefaultLibraryPath implements compiler.CompilerHost. -func (l *LanguageService) DefaultLibraryPath() string { - return l.host.DefaultLibraryPath() -} - -// GetCurrentDirectory implements compiler.CompilerHost. -func (l *LanguageService) GetCurrentDirectory() string { - return l.host.GetCurrentDirectory() -} - -// NewLine implements compiler.CompilerHost. -func (l *LanguageService) NewLine() string { - return l.host.NewLine() -} - -// Trace implements compiler.CompilerHost. -func (l *LanguageService) Trace(msg string) { - l.host.Trace(msg) +// GetProgram updates the program if the project version has changed. +func (l *LanguageService) GetProgram() *compiler.Program { + return l.host.GetProgram() } -// GetSourceFile implements compiler.CompilerHost. -func (l *LanguageService) GetSourceFile(fileName string, path tspath.Path, languageVersion core.ScriptTarget) *ast.SourceFile { - return l.host.GetSourceFile(fileName, path, languageVersion) +func (l *LanguageService) GetTypeChecker(file *ast.SourceFile) *checker.Checker { + var checker *checker.Checker + var done func() + if file == nil { + checker, done = l.GetProgram().GetTypeChecker(l.ctx) + } else { + checker, done = l.GetProgram().GetTypeCheckerForFile(l.ctx, file) + } + l.disposables = append(l.disposables, done) + return checker } -// GetProgram updates the program if the project version has changed. -func (l *LanguageService) GetProgram() *compiler.Program { - return l.host.GetProgram() +func (l *LanguageService) Dispose() { + for _, dispose := range l.disposables { + dispose() + } + l.disposables = nil } func (l *LanguageService) tryGetProgramAndFile(fileName string) (*compiler.Program, *ast.SourceFile) { @@ -63,7 +54,17 @@ func (l *LanguageService) tryGetProgramAndFile(fileName string) (*compiler.Progr return program, file } -func (l *LanguageService) getProgramAndFile(fileName string) (*compiler.Program, *ast.SourceFile) { +func (l *LanguageService) getSourceFile(documentURI lsproto.DocumentUri) *ast.SourceFile { + fileName := DocumentURIToFileName(documentURI) + _, file := l.tryGetProgramAndFile(fileName) + if file == nil { + return nil + } + return file +} + +func (l *LanguageService) getProgramAndFile(documentURI lsproto.DocumentUri) (*compiler.Program, *ast.SourceFile) { + fileName := DocumentURIToFileName(documentURI) program, file := l.tryGetProgramAndFile(fileName) if file == nil { panic("file not found: " + fileName) diff --git a/internal/ls/utilities.go b/internal/ls/utilities.go index 2f159d5639..02286e0b8c 100644 --- a/internal/ls/utilities.go +++ b/internal/ls/utilities.go @@ -268,10 +268,7 @@ func (l *LanguageService) createLspRangeFromNode(node *ast.Node, file *ast.Sourc } func (l *LanguageService) createLspRangeFromBounds(start, end int, file *ast.SourceFile) *lsproto.Range { - lspRange, err := l.converters.ToLSPRange(file.FileName(), core.NewTextRange(start, end)) - if err != nil { - panic(err) - } + lspRange := l.converters.ToLSPRange(file, core.NewTextRange(start, end)) return &lspRange } diff --git a/internal/lsp/lsproto/jsonrpc.go b/internal/lsp/lsproto/jsonrpc.go index ca54805215..36e756fc7a 100644 --- a/internal/lsp/lsproto/jsonrpc.go +++ b/internal/lsp/lsproto/jsonrpc.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "strconv" ) type JSONRPCVersion struct{} @@ -39,6 +40,13 @@ func NewIDString(str string) *ID { return &ID{str: str} } +func (id *ID) String() string { + if id.str != "" { + return id.str + } + return strconv.Itoa(int(id.int)) +} + func (id *ID) MarshalJSON() ([]byte, error) { if id.str != "" { return json.Marshal(id.str) diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 6c5c21d7c5..fc88cd5bb8 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -8,7 +8,6 @@ import ( "io" "runtime/debug" "slices" - "strings" "sync" "github.com/microsoft/typescript-go/internal/core" @@ -87,7 +86,6 @@ type Server struct { watchers core.Set[project.WatcherHandle] logger *project.Logger projectService *project.Service - converters *ls.Converters } // FS implements project.ServiceHost. @@ -126,7 +124,7 @@ func (s *Server) Client() project.Client { // WatchFiles implements project.Client. func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { watcherId := fmt.Sprintf("watcher-%d", s.watcherID) - respChan, err := s.sendRequest(lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ + _, err := s.sendRequest(lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ Registrations: []*lsproto.Registration{ { Id: watcherId, @@ -142,12 +140,6 @@ func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.Watc return "", fmt.Errorf("failed to register file watcher: %w", err) } - // TODO: timeout? - resp := <-respChan - if resp.Error != nil { - return "", fmt.Errorf("failed to register file watcher: %s", resp.Error.String()) - } - handle := project.WatcherHandle(watcherId) s.watchers.Add(handle) s.watcherID++ @@ -157,7 +149,7 @@ func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.Watc // UnwatchFiles implements project.Client. func (s *Server) UnwatchFiles(handle project.WatcherHandle) error { if s.watchers.Has(handle) { - respChan, err := s.sendRequest(lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{ + _, err := s.sendRequest(lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{ Unregisterations: []*lsproto.Unregistration{ { Id: string(handle), @@ -170,11 +162,6 @@ func (s *Server) UnwatchFiles(handle project.WatcherHandle) error { return fmt.Errorf("failed to unregister file watcher: %w", err) } - resp := <-respChan - if resp.Error != nil { - return fmt.Errorf("failed to unregister file watcher: %s", resp.Error.String()) - } - s.watchers.Delete(handle) return nil } @@ -185,7 +172,7 @@ func (s *Server) UnwatchFiles(handle project.WatcherHandle) error { // RefreshDiagnostics implements project.Client. func (s *Server) RefreshDiagnostics() error { if ptrIsTrue(s.initializeParams.Capabilities.Workspace.Diagnostics.RefreshSupport) { - if err := s.sendRequest(lsproto.MethodWorkspaceDiagnosticRefresh, nil); err != nil { + if _, err := s.sendRequest(lsproto.MethodWorkspaceDiagnosticRefresh, nil); err != nil { return fmt.Errorf("failed to refresh diagnostics: %w", err) } } @@ -276,7 +263,7 @@ func (s *Server) dispatchLoop() { if req.ID != nil { var cancel context.CancelFunc - ctx, cancel = context.WithCancel(ctx) + ctx, cancel = context.WithCancel(core.WithRequestID(ctx, req.ID.String())) s.pendingClientRequestsMu.Lock() s.pendingClientRequests[*req.ID] = pendingClientRequest{ req: req, @@ -285,9 +272,21 @@ func (s *Server) dispatchLoop() { s.pendingClientRequestsMu.Unlock() } - if err := s.handleRequestOrNotification(ctx, req); err != nil { - s.fatalErrChan <- err - return + handle := func() { + if err := s.handleRequestOrNotification(ctx, req); err != nil { + s.fatalErrChan <- err + } + if req.ID != nil { + s.pendingClientRequestsMu.Lock() + delete(s.pendingClientRequests, *req.ID) + s.pendingClientRequestsMu.Unlock() + } + } + + if isBlockingMethod(req.Method) { + handle() + } else { + go handle() } } } @@ -306,7 +305,7 @@ func (s *Server) writeLoop() { } } -func (s *Server) sendRequest(method lsproto.Method, params any) (<-chan *lsproto.ResponseMessage, error) { +func (s *Server) sendRequest(method lsproto.Method, params any) (any, error) { s.clientSeq++ id := lsproto.NewIDString(fmt.Sprintf("ts%d", s.clientSeq)) req := lsproto.NewRequestMessage(method, id, params) @@ -317,16 +316,13 @@ func (s *Server) sendRequest(method lsproto.Method, params any) (<-chan *lsproto s.pendingServerRequestsMu.Unlock() s.outgoingQueue <- req.Message() - return responseChan, nil -} -func (s *Server) sendNotification(method lsproto.Method, params any) error { - req := lsproto.NewRequestMessage(method, nil /*id*/, params) - data, err := json.Marshal(req) - if err != nil { - return err + // TODO: timeout? + resp := <-responseChan + if resp.Error != nil { + return nil, fmt.Errorf("request failed: %s", resp.Error.String()) } - return s.w.Write(data) + return resp.Result, nil } func (s *Server) sendResult(id *lsproto.ID, result any) { @@ -460,10 +456,6 @@ func (s *Server) handleInitialized(ctx context.Context, req *lsproto.RequestMess PositionEncoding: s.positionEncoding, }) - s.converters = ls.NewConverters(s.positionEncoding, func(fileName string) ls.ScriptInfo { - return s.projectService.GetScriptInfo(fileName) - }) - return nil } @@ -475,34 +467,7 @@ func (s *Server) handleDidOpen(ctx context.Context, req *lsproto.RequestMessage) func (s *Server) handleDidChange(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidChangeTextDocumentParams) - scriptInfo := s.projectService.GetScriptInfo(ls.DocumentURIToFileName(params.TextDocument.Uri)) - if scriptInfo == nil { - s.sendError(req.ID, lsproto.ErrRequestFailed) - return nil - } - - changes := make([]ls.TextChange, len(params.ContentChanges)) - for i, change := range params.ContentChanges { - if partialChange := change.TextDocumentContentChangePartial; partialChange != nil { - if textChange, err := s.converters.FromLSPTextChange(partialChange, scriptInfo.FileName()); err != nil { - s.sendError(req.ID, err) - return nil - } else { - changes[i] = textChange - } - } else if wholeChange := change.TextDocumentContentChangeWholeDocument; wholeChange != nil { - changes[i] = ls.TextChange{ - TextRange: core.NewTextRange(0, len(scriptInfo.Text())), - NewText: wholeChange.Text, - } - } else { - s.sendError(req.ID, lsproto.ErrInvalidRequest) - return nil - } - } - - s.projectService.ChangeFile(ls.DocumentURIToFileName(params.TextDocument.Uri), changes) - return nil + return s.projectService.ChangeFile(params.TextDocument, params.ContentChanges) } func (s *Server) handleDidSave(ctx context.Context, req *lsproto.RequestMessage) error { @@ -524,83 +489,48 @@ func (s *Server) handleDidChangeWatchedFiles(ctx context.Context, req *lsproto.R func (s *Server) handleDocumentDiagnostic(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DocumentDiagnosticParams) - file, project := s.getFileAndProject(params.TextDocument.Uri) - diagnostics := project.LanguageService().GetDocumentDiagnostics(file.FileName()) - lspDiagnostics := make([]*lsproto.Diagnostic, len(diagnostics)) - for i, diag := range diagnostics { - if lspDiagnostic, err := s.converters.ToLSPDiagnostic(diag); err != nil { - s.sendError(req.ID, err) - return nil - } else { - lspDiagnostics[i] = lspDiagnostic - } + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + diagnostics, err := languageService.GetDocumentDiagnostics(params.TextDocument.Uri) + if err != nil { + return err } - s.sendResult(req.ID, &lsproto.DocumentDiagnosticReport{ - RelatedFullDocumentDiagnosticReport: &lsproto.RelatedFullDocumentDiagnosticReport{ - FullDocumentDiagnosticReport: lsproto.FullDocumentDiagnosticReport{ - Kind: lsproto.StringLiteralFull{}, - Items: lspDiagnostics, - }, - }, - }) + s.sendResult(req.ID, diagnostics) return nil } func (s *Server) handleHover(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.HoverParams) - file, project := s.getFileAndProject(params.TextDocument.Uri) - pos, err := s.converters.LineAndCharacterToPositionForFile(params.Position, file.FileName()) + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + hover, err := languageService.ProvideHover(params.TextDocument.Uri, params.Position) if err != nil { - s.sendError(req.ID, err) - return nil + return err } - - hoverText := project.LanguageService().ProvideHover(file.FileName(), pos) - s.sendResult(req.ID, &lsproto.Hover{ - Contents: lsproto.MarkupContentOrMarkedStringOrMarkedStrings{ - MarkupContent: &lsproto.MarkupContent{ - Kind: lsproto.MarkupKindMarkdown, - Value: codeFence("ts", hoverText), - }, - }, - }) - + s.sendResult(req.ID, hover) return nil } func (s *Server) handleDefinition(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DefinitionParams) - file, project := s.getFileAndProject(params.TextDocument.Uri) - pos, err := s.converters.LineAndCharacterToPositionForFile(params.Position, file.FileName()) + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + definition, err := languageService.ProvideDefinition(params.TextDocument.Uri, params.Position) if err != nil { - s.sendError(req.ID, err) - return nil - } - - locations := project.LanguageService().ProvideDefinitions(file.FileName(), pos) - lspLocations := make([]lsproto.Location, len(locations)) - for i, loc := range locations { - if lspLocation, err := s.converters.ToLSPLocation(loc); err != nil { - s.sendError(req.ID, err) - return nil - } else { - lspLocations[i] = lspLocation - } + return err } - - s.sendResult(req.ID, &lsproto.Definition{Locations: &lspLocations}) + s.sendResult(req.ID, definition) return nil } func (s *Server) handleCompletion(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.CompletionParams) - file, project := s.getFileAndProject(params.TextDocument.Uri) - pos, err := s.converters.LineAndCharacterToPositionForFile(params.Position, file.FileName()) - if err != nil { - s.sendError(req.ID, err) - return nil - } - + project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() // !!! remove this after completions is fully ported/tested defer func() { if r := recover(); r != nil { @@ -610,46 +540,36 @@ func (s *Server) handleCompletion(ctx context.Context, req *lsproto.RequestMessa } }() // !!! get user preferences - list := project.LanguageService().ProvideCompletion( - file.FileName(), - pos, + list, err := languageService.ProvideCompletion( + params.TextDocument.Uri, + params.Position, params.Context, s.initializeParams.Capabilities.TextDocument.Completion, &ls.UserPreferences{}) + + if err != nil { + return err + } s.sendResult(req.ID, list) return nil } -func (s *Server) getFileAndProject(uri lsproto.DocumentUri) (*project.ScriptInfo, *project.Project) { - fileName := ls.DocumentURIToFileName(uri) - return s.projectService.EnsureDefaultProjectForFile(fileName) -} - func (s *Server) Log(msg ...any) { fmt.Fprintln(s.stderr, msg...) } -func codeFence(lang string, code string) string { - if code == "" { - return "" - } - ticks := 3 - for strings.Contains(code, strings.Repeat("`", ticks)) { - ticks++ - } - var result strings.Builder - result.Grow(len(code) + len(lang) + 2*ticks + 2) - for range ticks { - result.WriteByte('`') - } - result.WriteString(lang) - result.WriteByte('\n') - result.WriteString(code) - result.WriteByte('\n') - for range ticks { - result.WriteByte('`') +func isBlockingMethod(method lsproto.Method) bool { + switch method { + case lsproto.MethodInitialize, + lsproto.MethodInitialized, + lsproto.MethodTextDocumentDidOpen, + lsproto.MethodTextDocumentDidChange, + lsproto.MethodTextDocumentDidSave, + lsproto.MethodTextDocumentDidClose, + lsproto.MethodWorkspaceDidChangeWatchedFiles: + return true } - return result.String() + return false } func ptrTo[T any](v T) *T { diff --git a/internal/project/checkerpool.go b/internal/project/checkerpool.go new file mode 100644 index 0000000000..829f53db2a --- /dev/null +++ b/internal/project/checkerpool.go @@ -0,0 +1,163 @@ +package project + +import ( + "context" + "iter" + "sync" + + "github.com/microsoft/typescript-go/internal/ast" + "github.com/microsoft/typescript-go/internal/checker" + "github.com/microsoft/typescript-go/internal/compiler" + "github.com/microsoft/typescript-go/internal/core" +) + +type CheckerPool struct { + maxCheckers int + program *compiler.Program + + mu sync.Mutex + cond *sync.Cond + createCheckersOnce sync.Once + checkers []*checker.Checker + inUse map[*checker.Checker]bool + fileAssociations map[*ast.SourceFile]*checker.Checker + requestAssociations map[string]*checker.Checker +} + +var _ compiler.CheckerPool = (*CheckerPool)(nil) + +func newCheckerPool(maxCheckers int, program *compiler.Program) *CheckerPool { + pool := &CheckerPool{ + program: program, + maxCheckers: maxCheckers, + checkers: make([]*checker.Checker, 0, maxCheckers), + inUse: make(map[*checker.Checker]bool), + requestAssociations: make(map[string]*checker.Checker), + } + + pool.cond = sync.NewCond(&pool.mu) + return pool +} + +func (p *CheckerPool) GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) { + p.mu.Lock() + defer p.mu.Unlock() + + requestID := core.GetRequestID(ctx) + if requestID != "" { + if checker, ok := p.requestAssociations[requestID]; ok { + if inUse := p.inUse[checker]; !inUse { + p.inUse[checker] = true + return checker, p.createRelease(requestID, checker) + } + return checker, noop + } + } + + if p.fileAssociations == nil { + p.fileAssociations = make(map[*ast.SourceFile]*checker.Checker) + } + + if checker, ok := p.fileAssociations[file]; ok { + if inUse := p.inUse[checker]; !inUse { + p.inUse[checker] = true + if requestID != "" { + p.requestAssociations[requestID] = checker + } + return checker, p.createRelease(requestID, checker) + } + } + + checker, release := p.getCheckerLocked(requestID) + p.fileAssociations[file] = checker + return checker, release +} + +func (p *CheckerPool) GetChecker(ctx context.Context) (*checker.Checker, func()) { + p.mu.Lock() + defer p.mu.Unlock() + return p.getCheckerLocked(core.GetRequestID(ctx)) +} + +func (p *CheckerPool) Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] { + panic("unimplemented") +} + +func (p *CheckerPool) GetAllCheckers(ctx context.Context) ([]*checker.Checker, func()) { + requestID := core.GetRequestID(ctx) + if requestID == "" { + panic("cannot call GetAllCheckers on a project.checkerPool without a request ID") + } + + // A request can only access one checker + c, release := p.GetChecker(ctx) + return []*checker.Checker{c}, release +} + +func (p *CheckerPool) getCheckerLocked(requestID string) (*checker.Checker, func()) { + if checker := p.getImmediatelyAvailableChecker(); checker != nil { + p.inUse[checker] = true + if requestID != "" { + p.requestAssociations[requestID] = checker + } + return checker, p.createRelease(requestID, checker) + } + + if len(p.checkers) < p.maxCheckers { + checker := p.createCheckerLocked() + p.inUse[checker] = true + if requestID != "" { + p.requestAssociations[requestID] = checker + } + return checker, p.createRelease(requestID, checker) + } + + checker := p.waitForAvailableChecker() + p.inUse[checker] = true + if requestID != "" { + p.requestAssociations[requestID] = checker + } + return checker, p.createRelease(requestID, checker) +} + +func (p *CheckerPool) getImmediatelyAvailableChecker() *checker.Checker { + if len(p.checkers) == 0 { + return nil + } + + for _, checker := range p.checkers { + if inUse := p.inUse[checker]; !inUse { + return checker + } + } + + return nil +} + +func (p *CheckerPool) waitForAvailableChecker() *checker.Checker { + for { + p.cond.Wait() + checker := p.getImmediatelyAvailableChecker() + if checker != nil { + return checker + } + } +} + +func (p *CheckerPool) createRelease(requestId string, checker *checker.Checker) func() { + return func() { + p.mu.Lock() + defer p.mu.Unlock() + + p.inUse[checker] = false + p.cond.Signal() + } +} + +func (p *CheckerPool) createCheckerLocked() *checker.Checker { + checker := checker.NewChecker(p.program) + p.checkers = append(p.checkers, checker) + return checker +} + +func noop() {} diff --git a/internal/project/project.go b/internal/project/project.go index ff49305911..917ee69a6c 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -1,6 +1,7 @@ package project import ( + "context" "fmt" "maps" "slices" @@ -23,8 +24,6 @@ const hr = "-----------------------------------------------" var projectNamer = &namer{} -var _ ls.Host = (*Project)(nil) - type Kind int const ( @@ -34,6 +33,34 @@ const ( KindAuxiliary ) +type snapshot struct { + project *Project + positionEncoding lsproto.PositionEncodingKind + program *compiler.Program +} + +// GetLineMap implements ls.Host. +func (s *snapshot) GetLineMap(fileName string) *ls.LineMap { + file := s.program.GetSourceFile(fileName) + scriptInfo := s.project.host.GetScriptInfoByPath(file.Path()) + if file.Version == scriptInfo.Version() { + return scriptInfo.LineMap() + } + return ls.ComputeLineStarts(file.Text()) +} + +// GetPositionEncoding implements ls.Host. +func (s *snapshot) GetPositionEncoding() lsproto.PositionEncodingKind { + return s.positionEncoding +} + +// GetProgram implements ls.Host. +func (s *snapshot) GetProgram() *compiler.Program { + return s.program +} + +var _ ls.Host = (*snapshot)(nil) + type PendingReload int const ( @@ -84,7 +111,6 @@ type Project struct { rootFileNames *collections.OrderedMap[tspath.Path, string] compilerOptions *core.CompilerOptions parsedCommandLine *tsoptions.ParsedCommandLine - languageService *ls.LanguageService program *compiler.Program // Watchers @@ -134,7 +160,6 @@ func NewProject(name string, kind Kind, currentDirectory string, host ProjectHos return slices.Sorted(maps.Values(data)) }) } - project.languageService = ls.NewLanguageService(project) project.markAsDirty() return project } @@ -207,11 +232,6 @@ func (p *Project) GetDefaultLibraryPath() string { return p.host.DefaultLibraryPath() } -// GetScriptInfo implements ls.Host. -func (p *Project) GetScriptInfo(fileName string) ls.ScriptInfo { - return p.host.GetScriptInfoByPath(p.toPath(fileName)) -} - // GetPositionEncoding implements ls.Host. func (p *Project) GetPositionEncoding() lsproto.PositionEncodingKind { return p.host.PositionEncoding() @@ -233,8 +253,26 @@ func (p *Project) CurrentProgram() *compiler.Program { return p.program } +func (p *Project) GetLanguageServiceForRequest(ctx context.Context) (*ls.LanguageService, func()) { + if core.GetRequestID(ctx) == "" { + panic("context must already have a request ID") + } + snapshot := &snapshot{ + project: p, + positionEncoding: p.host.PositionEncoding(), + program: p.GetProgram(), + } + languageService := ls.NewLanguageService(ctx, snapshot) + return languageService, languageService.Dispose +} + func (p *Project) LanguageService() *ls.LanguageService { - return p.languageService + snapshot := &snapshot{ + project: p, + positionEncoding: p.host.PositionEncoding(), + program: p.GetProgram(), + } + return ls.NewLanguageService(nil /*context*/, snapshot) } func (p *Project) getRootFileWatchGlobs() []string { @@ -422,12 +460,15 @@ func (p *Project) updateProgram() { rootFileNames := p.GetRootFileNames() compilerOptions := p.GetCompilerOptions() - p.program = compiler.NewProgram(compiler.ProgramOptions{ - RootFiles: rootFileNames, - Host: p, - Options: compilerOptions, + var program compiler.Program + program = *compiler.NewProgram(compiler.ProgramOptions{ + RootFiles: rootFileNames, + Host: p, + Options: compilerOptions, + CheckerPool: newCheckerPool(4, &program), }) + p.program = &program p.program.BindSourceFiles() } diff --git a/internal/project/scriptinfo.go b/internal/project/scriptinfo.go index dfa18a67af..a33dcbd297 100644 --- a/internal/project/scriptinfo.go +++ b/internal/project/scriptinfo.go @@ -9,7 +9,7 @@ import ( "github.com/microsoft/typescript-go/internal/vfs" ) -var _ ls.ScriptInfo = (*ScriptInfo)(nil) +var _ ls.Script = (*ScriptInfo)(nil) type ScriptInfo struct { fileName string diff --git a/internal/project/service.go b/internal/project/service.go index d97cbf1bfe..730acf1b50 100644 --- a/internal/project/service.go +++ b/internal/project/service.go @@ -85,8 +85,8 @@ func NewService(host ServiceHost, options ServiceOptions) *Service { realpathToScriptInfos: make(map[tspath.Path]map[*ScriptInfo]struct{}), } - service.converters = ls.NewConverters(options.PositionEncoding, func(fileName string) ls.ScriptInfo { - return service.GetScriptInfo(fileName) + service.converters = ls.NewConverters(options.PositionEncoding, func(fileName string) *ls.LineMap { + return service.GetScriptInfo(fileName).LineMap() }) return service @@ -177,13 +177,30 @@ func (s *Service) OpenFile(fileName string, fileContent string, scriptKind core. s.printProjects() } -func (s *Service) ChangeFile(fileName string, changes []ls.TextChange) { +func (s *Service) ChangeFile(document lsproto.VersionedTextDocumentIdentifier, changes []lsproto.TextDocumentContentChangeEvent) error { + fileName := ls.DocumentURIToFileName(document.Uri) path := s.toPath(fileName) - info := s.GetScriptInfoByPath(path) - if info == nil { - panic("scriptInfo not found") + scriptInfo := s.GetScriptInfoByPath(path) + if scriptInfo == nil { + return fmt.Errorf("file %s not found", fileName) + } + + textChanges := make([]ls.TextChange, len(changes)) + for i, change := range changes { + if partialChange := change.TextDocumentContentChangePartial; partialChange != nil { + textChanges[i] = s.converters.FromLSPTextChange(scriptInfo, partialChange) + } else if wholeChange := change.TextDocumentContentChangeWholeDocument; wholeChange != nil { + textChanges[i] = ls.TextChange{ + TextRange: core.NewTextRange(0, len(scriptInfo.Text())), + NewText: wholeChange.Text, + } + } else { + return fmt.Errorf("invalid change type") + } } - s.applyChangesToFile(info, changes) + + s.applyChangesToFile(scriptInfo, textChanges) + return nil } func (s *Service) CloseFile(fileName string) { @@ -208,6 +225,11 @@ func (s *Service) MarkFileSaved(fileName string, text string) { } } +func (s *Service) EnsureDefaultProjectForURI(url lsproto.DocumentUri) *Project { + _, project := s.EnsureDefaultProjectForFile(ls.DocumentURIToFileName(url)) + return project +} + func (s *Service) EnsureDefaultProjectForFile(fileName string) (*ScriptInfo, *Project) { path := s.toPath(fileName) if info := s.GetScriptInfoByPath(path); info != nil && !info.isOrphan() { diff --git a/internal/project/service_test.go b/internal/project/service_test.go index 3658ea7ebb..476eee0ae2 100644 --- a/internal/project/service_test.go +++ b/internal/project/service_test.go @@ -7,7 +7,6 @@ import ( "github.com/microsoft/typescript-go/internal/bundled" "github.com/microsoft/typescript-go/internal/core" - "github.com/microsoft/typescript-go/internal/ls" "github.com/microsoft/typescript-go/internal/lsp/lsproto" "github.com/microsoft/typescript-go/internal/project" "github.com/microsoft/typescript-go/internal/testutil/projecttestutil" @@ -94,7 +93,31 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/x.ts", defaultFiles["/home/projects/TS/p1/src/x.ts"], core.ScriptKindTS, "") info, proj := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/x.ts") programBefore := proj.GetProgram() - service.ChangeFile("/home/projects/TS/p1/src/x.ts", []ls.TextChange{{TextRange: core.NewTextRange(17, 18), NewText: "2"}}) + service.ChangeFile( + lsproto.VersionedTextDocumentIdentifier{ + TextDocumentIdentifier: lsproto.TextDocumentIdentifier{ + Uri: "file:///home/projects/TS/p1/src/x.ts", + }, + Version: 1, + }, + []lsproto.TextDocumentContentChangeEvent{ + lsproto.TextDocumentContentChangePartialOrTextDocumentContentChangeWholeDocument{ + TextDocumentContentChangePartial: ptrTo(lsproto.TextDocumentContentChangePartial{ + Range: lsproto.Range{ + Start: lsproto.Position{ + Line: 0, + Character: 17, + }, + End: lsproto.Position{ + Line: 0, + Character: 18, + }, + }, + Text: "2", + }), + }, + }, + ) assert.Equal(t, info.Text(), "export const x = 2;") assert.Equal(t, proj.CurrentProgram(), programBefore) assert.Equal(t, programBefore.GetSourceFile("/home/projects/TS/p1/src/x.ts").Text(), "export const x = 1;") @@ -108,7 +131,31 @@ func TestService(t *testing.T) { _, proj := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/x.ts") programBefore := proj.GetProgram() indexFileBefore := programBefore.GetSourceFile("/home/projects/TS/p1/src/index.ts") - service.ChangeFile("/home/projects/TS/p1/src/x.ts", nil) + service.ChangeFile( + lsproto.VersionedTextDocumentIdentifier{ + TextDocumentIdentifier: lsproto.TextDocumentIdentifier{ + Uri: "file:///home/projects/TS/p1/src/x.ts", + }, + Version: 1, + }, + []lsproto.TextDocumentContentChangeEvent{ + lsproto.TextDocumentContentChangePartialOrTextDocumentContentChangeWholeDocument{ + TextDocumentContentChangePartial: ptrTo(lsproto.TextDocumentContentChangePartial{ + Range: lsproto.Range{ + Start: lsproto.Position{ + Line: 0, + Character: 0, + }, + End: lsproto.Position{ + Line: 0, + Character: 0, + }, + }, + Text: ";", + }), + }, + }, + ) assert.Equal(t, proj.GetProgram().GetSourceFile("/home/projects/TS/p1/src/index.ts"), indexFileBefore) }) @@ -120,7 +167,31 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/index.ts", files["/home/projects/TS/p1/src/index.ts"], core.ScriptKindTS, "") assert.Check(t, service.GetScriptInfo("/home/projects/TS/p1/y.ts") == nil) - service.ChangeFile("/home/projects/TS/p1/src/index.ts", []ls.TextChange{{TextRange: core.NewTextRange(0, 0), NewText: `import { y } from "../y";\n`}}) + service.ChangeFile( + lsproto.VersionedTextDocumentIdentifier{ + TextDocumentIdentifier: lsproto.TextDocumentIdentifier{ + Uri: "file:///home/projects/TS/p1/src/index.ts", + }, + Version: 1, + }, + []lsproto.TextDocumentContentChangeEvent{ + lsproto.TextDocumentContentChangePartialOrTextDocumentContentChangeWholeDocument{ + TextDocumentContentChangePartial: ptrTo(lsproto.TextDocumentContentChangePartial{ + Range: lsproto.Range{ + Start: lsproto.Position{ + Line: 0, + Character: 0, + }, + End: lsproto.Position{ + Line: 0, + Character: 0, + }, + }, + Text: `import { y } from "../y";\n`, + }), + }, + }, + ) service.EnsureDefaultProjectForFile("/home/projects/TS/p1/y.ts") }) }) diff --git a/internal/testrunner/compiler_runner.go b/internal/testrunner/compiler_runner.go index 5f6c48091b..f09a987296 100644 --- a/internal/testrunner/compiler_runner.go +++ b/internal/testrunner/compiler_runner.go @@ -455,7 +455,9 @@ func createHarnessTestFile(unit *testUnit, currentDirectory string) *harnessutil func (c *compilerTest) verifyUnionOrdering(t *testing.T) { t.Run("union ordering", func(t *testing.T) { - for _, c := range c.result.Program.GetTypeCheckers() { + checkers, done := c.result.Program.GetTypeCheckers(t.Context()) + defer done() + for _, c := range checkers { for union := range c.UnionTypes() { types := union.Types() diff --git a/internal/testutil/harnessutil/harnessutil.go b/internal/testutil/harnessutil/harnessutil.go index 71526477f7..3f44747893 100644 --- a/internal/testutil/harnessutil/harnessutil.go +++ b/internal/testutil/harnessutil/harnessutil.go @@ -537,11 +537,12 @@ func compileFilesWithHost( // ...ts.filter(longerErrors!, p => !ts.some(shorterErrors, p2 => ts.compareDiagnostics(p, p2) === ts.Comparison.EqualTo)), // ), // ] : postErrors; + ctx := context.Background() program := createProgram(host, options, rootFiles) var diagnostics []*ast.Diagnostic - diagnostics = append(diagnostics, program.GetSyntacticDiagnostics(context.Background(), nil)...) - diagnostics = append(diagnostics, program.GetSemanticDiagnostics(context.Background(), nil)...) - diagnostics = append(diagnostics, program.GetGlobalDiagnostics()...) + diagnostics = append(diagnostics, program.GetSyntacticDiagnostics(ctx, nil)...) + diagnostics = append(diagnostics, program.GetSemanticDiagnostics(ctx, nil)...) + diagnostics = append(diagnostics, program.GetGlobalDiagnostics(ctx)...) emitResult := program.Emit(compiler.EmitOptions{}) return newCompilationResult(options, program, emitResult, diagnostics, harnessOptions) diff --git a/internal/testutil/lstestutil/lstestutil.go b/internal/testutil/lstestutil/lstestutil.go index fc81641246..7929ab4972 100644 --- a/internal/testutil/lstestutil/lstestutil.go +++ b/internal/testutil/lstestutil/lstestutil.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/microsoft/typescript-go/internal/core" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) type markerRange struct { @@ -15,9 +16,10 @@ type markerRange struct { } type Marker struct { - Filename string - Position int - Name string + Filename string + Position int + LSPosition lsproto.Position + Name string } type TestData struct { @@ -174,7 +176,11 @@ func recordMarker( marker := &Marker{ Filename: filename, Position: location.position, - Name: name, + LSPosition: lsproto.Position{ + Line: uint32(location.sourceLine - 1), + Character: uint32(location.sourceColumn - 1), + }, + Name: name, } // Verify markers for uniqueness if _, ok := markerMap[name]; ok { diff --git a/internal/testutil/tsbaseline/type_symbol_baseline.go b/internal/testutil/tsbaseline/type_symbol_baseline.go index ef4336490d..3cc5b44c9d 100644 --- a/internal/testutil/tsbaseline/type_symbol_baseline.go +++ b/internal/testutil/tsbaseline/type_symbol_baseline.go @@ -1,6 +1,7 @@ package tsbaseline import ( + "context" "fmt" "regexp" "slices" @@ -263,10 +264,10 @@ func newTypeWriterWalker(program *compiler.Program, hadErrorBaseline bool) *type } } -func (walker *typeWriterWalker) getTypeCheckerForCurrentFile() *checker.Checker { +func (walker *typeWriterWalker) getTypeCheckerForCurrentFile() (*checker.Checker, func()) { // If we don't use the right checker for the file, its contents won't be up to date // since the types/symbols baselines appear to depend on files having been checked. - return walker.program.GetTypeCheckerForFile(walker.currentSourceFile) + return walker.program.GetTypeCheckerForFile(context.Background(), walker.currentSourceFile) } type typeWriterResult struct { @@ -332,7 +333,8 @@ func (walker *typeWriterWalker) writeTypeOrSymbol(node *ast.Node, isSymbolWalk b actualPos := scanner.SkipTrivia(walker.currentSourceFile.Text(), node.Pos()) line, _ := scanner.GetLineAndCharacterOfPosition(walker.currentSourceFile, actualPos) sourceText := scanner.GetSourceTextOfNodeFromSourceFile(walker.currentSourceFile, node, false /*includeTrivia*/) - fileChecker := walker.getTypeCheckerForCurrentFile() + fileChecker, done := walker.getTypeCheckerForCurrentFile() + defer done() if !isSymbolWalk { // Don't try to get the type of something that's already a type. From 3ef2f90c8dae932465344b874a544d1f516dfab0 Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Wed, 14 May 2025 12:38:02 -0700 Subject: [PATCH 04/12] Remove canceled checkers from the pool --- internal/checker/exports.go | 4 + internal/project/checkerpool.go | 125 ++++++++++++++++++++------------ 2 files changed, 81 insertions(+), 48 deletions(-) diff --git a/internal/checker/exports.go b/internal/checker/exports.go index 4bf67fbba0..d04b7a97cc 100644 --- a/internal/checker/exports.go +++ b/internal/checker/exports.go @@ -48,3 +48,7 @@ func (c *Checker) GetTypeOfPropertyOfContextualType(t *Type, name string) *Type func GetDeclarationModifierFlagsFromSymbol(s *ast.Symbol) ast.ModifierFlags { return getDeclarationModifierFlagsFromSymbol(s) } + +func (c *Checker) WasCanceled() bool { + return c.wasCanceled +} diff --git a/internal/project/checkerpool.go b/internal/project/checkerpool.go index 829f53db2a..e2bb313cdb 100644 --- a/internal/project/checkerpool.go +++ b/internal/project/checkerpool.go @@ -20,8 +20,8 @@ type CheckerPool struct { createCheckersOnce sync.Once checkers []*checker.Checker inUse map[*checker.Checker]bool - fileAssociations map[*ast.SourceFile]*checker.Checker - requestAssociations map[string]*checker.Checker + fileAssociations map[*ast.SourceFile]int + requestAssociations map[string]int } var _ compiler.CheckerPool = (*CheckerPool)(nil) @@ -30,9 +30,9 @@ func newCheckerPool(maxCheckers int, program *compiler.Program) *CheckerPool { pool := &CheckerPool{ program: program, maxCheckers: maxCheckers, - checkers: make([]*checker.Checker, 0, maxCheckers), + checkers: make([]*checker.Checker, maxCheckers), inUse: make(map[*checker.Checker]bool), - requestAssociations: make(map[string]*checker.Checker), + requestAssociations: make(map[string]int), } pool.cond = sync.NewCond(&pool.mu) @@ -45,38 +45,47 @@ func (p *CheckerPool) GetCheckerForFile(ctx context.Context, file *ast.SourceFil requestID := core.GetRequestID(ctx) if requestID != "" { - if checker, ok := p.requestAssociations[requestID]; ok { - if inUse := p.inUse[checker]; !inUse { - p.inUse[checker] = true - return checker, p.createRelease(requestID, checker) + if index, ok := p.requestAssociations[requestID]; ok { + checker := p.checkers[index] + if checker != nil { + if inUse := p.inUse[checker]; !inUse { + p.inUse[checker] = true + return checker, p.createRelease(requestID, index, checker) + } + // Checker is in use, but by the same request - assume it's the + // same goroutine or is managing its own synchronization + return checker, noop } - return checker, noop } } if p.fileAssociations == nil { - p.fileAssociations = make(map[*ast.SourceFile]*checker.Checker) + p.fileAssociations = make(map[*ast.SourceFile]int) } - if checker, ok := p.fileAssociations[file]; ok { - if inUse := p.inUse[checker]; !inUse { - p.inUse[checker] = true - if requestID != "" { - p.requestAssociations[requestID] = checker + if index, ok := p.fileAssociations[file]; ok { + checker := p.checkers[index] + if checker != nil { + if inUse := p.inUse[checker]; !inUse { + p.inUse[checker] = true + if requestID != "" { + p.requestAssociations[requestID] = index + } + return checker, p.createRelease(requestID, index, checker) } - return checker, p.createRelease(requestID, checker) } } - checker, release := p.getCheckerLocked(requestID) - p.fileAssociations[file] = checker - return checker, release + checker, index := p.getCheckerLocked(requestID) + p.fileAssociations[file] = index + return checker, p.createRelease(requestID, index, checker) } func (p *CheckerPool) GetChecker(ctx context.Context) (*checker.Checker, func()) { p.mu.Lock() defer p.mu.Unlock() - return p.getCheckerLocked(core.GetRequestID(ctx)) + checker, index := p.getCheckerLocked(core.GetRequestID(ctx)) + return checker, p.createRelease(core.GetRequestID(ctx), index, checker) } func (p *CheckerPool) Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] { @@ -94,70 +103,90 @@ func (p *CheckerPool) GetAllCheckers(ctx context.Context) ([]*checker.Checker, f return []*checker.Checker{c}, release } -func (p *CheckerPool) getCheckerLocked(requestID string) (*checker.Checker, func()) { - if checker := p.getImmediatelyAvailableChecker(); checker != nil { +func (p *CheckerPool) getCheckerLocked(requestID string) (*checker.Checker, int) { + if checker, index := p.getImmediatelyAvailableChecker(); checker != nil { p.inUse[checker] = true if requestID != "" { - p.requestAssociations[requestID] = checker + p.requestAssociations[requestID] = index } - return checker, p.createRelease(requestID, checker) + return checker, index } - if len(p.checkers) < p.maxCheckers { - checker := p.createCheckerLocked() + if !p.isFullLocked() { + checker, index := p.createCheckerLocked() p.inUse[checker] = true if requestID != "" { - p.requestAssociations[requestID] = checker + p.requestAssociations[requestID] = index } - return checker, p.createRelease(requestID, checker) + return checker, index } - checker := p.waitForAvailableChecker() + checker, index := p.waitForAvailableChecker() p.inUse[checker] = true if requestID != "" { - p.requestAssociations[requestID] = checker + p.requestAssociations[requestID] = index } - return checker, p.createRelease(requestID, checker) + return checker, index } -func (p *CheckerPool) getImmediatelyAvailableChecker() *checker.Checker { - if len(p.checkers) == 0 { - return nil - } - - for _, checker := range p.checkers { +func (p *CheckerPool) getImmediatelyAvailableChecker() (*checker.Checker, int) { + for i, checker := range p.checkers { + if checker == nil { + continue + } if inUse := p.inUse[checker]; !inUse { - return checker + return checker, i } } - return nil + return nil, -1 } -func (p *CheckerPool) waitForAvailableChecker() *checker.Checker { +func (p *CheckerPool) waitForAvailableChecker() (*checker.Checker, int) { for { p.cond.Wait() - checker := p.getImmediatelyAvailableChecker() + checker, index := p.getImmediatelyAvailableChecker() if checker != nil { - return checker + return checker, index } } } -func (p *CheckerPool) createRelease(requestId string, checker *checker.Checker) func() { +func (p *CheckerPool) createRelease(requestId string, index int, checker *checker.Checker) func() { return func() { p.mu.Lock() defer p.mu.Unlock() - p.inUse[checker] = false + delete(p.requestAssociations, requestId) + if checker.WasCanceled() { + // Canceled checkers must be disposed + p.checkers[index] = nil + delete(p.inUse, checker) + } else { + p.inUse[checker] = false + } p.cond.Signal() } } -func (p *CheckerPool) createCheckerLocked() *checker.Checker { - checker := checker.NewChecker(p.program) - p.checkers = append(p.checkers, checker) - return checker +func (p *CheckerPool) isFullLocked() bool { + for _, checker := range p.checkers { + if checker == nil { + return false + } + } + return true +} + +func (p *CheckerPool) createCheckerLocked() (*checker.Checker, int) { + for i, existing := range p.checkers { + if existing == nil { + checker := checker.NewChecker(p.program) + p.checkers[i] = checker + return checker, i + } + } + panic("called createCheckerLocked when pool is full") } func noop() {} From afc5138dfe79838d1e39904bb55f92dabf698d4a Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Thu, 15 May 2025 16:22:30 -0700 Subject: [PATCH 05/12] Thread context through everything correctly --- internal/lsp/server.go | 172 ++++++++++-------- internal/project/host.go | 8 +- internal/project/project.go | 12 +- internal/project/service.go | 5 +- internal/project/service_test.go | 16 +- internal/project/watch.go | 7 +- .../projecttestutil/clientmock_generated.go | 46 +++-- 7 files changed, 157 insertions(+), 109 deletions(-) diff --git a/internal/lsp/server.go b/internal/lsp/server.go index fc88cd5bb8..136b3fb8ca 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -6,15 +6,19 @@ import ( "errors" "fmt" "io" + "os" + "os/signal" "runtime/debug" "slices" "sync" + "syscall" "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/ls" "github.com/microsoft/typescript-go/internal/lsp/lsproto" "github.com/microsoft/typescript-go/internal/project" "github.com/microsoft/typescript-go/internal/vfs" + "golang.org/x/sync/errgroup" ) type ServerOptions struct { @@ -36,7 +40,6 @@ func NewServer(opts *ServerOptions) *Server { r: lsproto.NewBaseReader(opts.In), w: lsproto.NewBaseWriter(opts.Out), stderr: opts.Err, - fatalErrChan: make(chan error, 1), requestQueue: make(chan *lsproto.RequestMessage, 100), outgoingQueue: make(chan *lsproto.Message, 100), pendingClientRequests: make(map[lsproto.ID]pendingClientRequest), @@ -65,7 +68,6 @@ type Server struct { stderr io.Writer clientSeq int32 - fatalErrChan chan error requestQueue chan *lsproto.RequestMessage outgoingQueue chan *lsproto.Message pendingClientRequests map[lsproto.ID]pendingClientRequest @@ -122,9 +124,9 @@ func (s *Server) Client() project.Client { } // WatchFiles implements project.Client. -func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { +func (s *Server) WatchFiles(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { watcherId := fmt.Sprintf("watcher-%d", s.watcherID) - _, err := s.sendRequest(lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ + _, err := s.sendRequest(ctx, lsproto.MethodClientRegisterCapability, &lsproto.RegistrationParams{ Registrations: []*lsproto.Registration{ { Id: watcherId, @@ -147,9 +149,9 @@ func (s *Server) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.Watc } // UnwatchFiles implements project.Client. -func (s *Server) UnwatchFiles(handle project.WatcherHandle) error { +func (s *Server) UnwatchFiles(ctx context.Context, handle project.WatcherHandle) error { if s.watchers.Has(handle) { - _, err := s.sendRequest(lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{ + _, err := s.sendRequest(ctx, lsproto.MethodClientUnregisterCapability, &lsproto.UnregistrationParams{ Unregisterations: []*lsproto.Unregistration{ { Id: string(handle), @@ -170,9 +172,9 @@ func (s *Server) UnwatchFiles(handle project.WatcherHandle) error { } // RefreshDiagnostics implements project.Client. -func (s *Server) RefreshDiagnostics() error { +func (s *Server) RefreshDiagnostics(ctx context.Context) error { if ptrIsTrue(s.initializeParams.Capabilities.Workspace.Diagnostics.RefreshSupport) { - if _, err := s.sendRequest(lsproto.MethodWorkspaceDiagnosticRefresh, nil); err != nil { + if _, err := s.sendRequest(ctx, lsproto.MethodWorkspaceDiagnosticRefresh, nil); err != nil { return fmt.Errorf("failed to refresh diagnostics: %w", err) } } @@ -180,27 +182,28 @@ func (s *Server) RefreshDiagnostics() error { } func (s *Server) Run() error { - go s.dispatchLoop() - go s.writeLoop() - go s.readLoop() - err := <-s.fatalErrChan - return err + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { return s.dispatchLoop(ctx) }) + g.Go(func() error { return s.writeLoop(ctx) }) + g.Go(func() error { return s.readLoop(ctx) }) + return g.Wait() } -func (s *Server) readLoop() { +func (s *Server) readLoop(ctx context.Context) error { for { msg, err := s.read() if err != nil { - if errors.Is(err, io.EOF) { - s.fatalErrChan <- nil - return + if err == io.EOF { + return nil } if errors.Is(err, lsproto.ErrInvalidRequest) { s.sendError(nil, err) continue } - s.fatalErrChan <- err - return + return err } if s.initializeParams == nil && msg.Kind == lsproto.MessageKindRequest { @@ -225,7 +228,7 @@ func (s *Server) readLoop() { } else { req := msg.AsRequest() if req.Method == lsproto.MethodCancelRequest { - go s.cancelRequest(req.Params.(*lsproto.CancelParams).Id) + s.cancelRequest(req.Params.(*lsproto.CancelParams).Id) } else { s.requestQueue <- req } @@ -257,55 +260,67 @@ func (s *Server) read() (*lsproto.Message, error) { return req, nil } -func (s *Server) dispatchLoop() { - for req := range s.requestQueue { - ctx := context.Background() - - if req.ID != nil { - var cancel context.CancelFunc - ctx, cancel = context.WithCancel(core.WithRequestID(ctx, req.ID.String())) - s.pendingClientRequestsMu.Lock() - s.pendingClientRequests[*req.ID] = pendingClientRequest{ - req: req, - cancel: cancel, - } - s.pendingClientRequestsMu.Unlock() - } - - handle := func() { - if err := s.handleRequestOrNotification(ctx, req); err != nil { - s.fatalErrChan <- err - } +func (s *Server) dispatchLoop(ctx context.Context) error { + ctx, lspExit := context.WithCancel(ctx) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-s.requestQueue: if req.ID != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(core.WithRequestID(ctx, req.ID.String())) s.pendingClientRequestsMu.Lock() - delete(s.pendingClientRequests, *req.ID) + s.pendingClientRequests[*req.ID] = pendingClientRequest{ + req: req, + cancel: cancel, + } s.pendingClientRequestsMu.Unlock() } - } - if isBlockingMethod(req.Method) { - handle() - } else { - go handle() + handle := func() { + if err := s.handleRequestOrNotification(ctx, req); err != nil { + if err == io.EOF { + lspExit() + } else { + s.sendError(req.ID, err) + } + } + + if req.ID != nil { + s.pendingClientRequestsMu.Lock() + delete(s.pendingClientRequests, *req.ID) + s.pendingClientRequestsMu.Unlock() + } + } + + if isBlockingMethod(req.Method) { + handle() + } else { + go handle() + } } } } -func (s *Server) writeLoop() { - for msg := range s.outgoingQueue { - data, err := json.Marshal(msg) - if err != nil { - s.fatalErrChan <- fmt.Errorf("failed to marshal message: %w", err) - continue - } - if err := s.w.Write(data); err != nil { - s.fatalErrChan <- fmt.Errorf("failed to write message: %w", err) - continue +func (s *Server) writeLoop(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case msg := <-s.outgoingQueue: + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + if err := s.w.Write(data); err != nil { + return fmt.Errorf("failed to write message: %w", err) + } } } } -func (s *Server) sendRequest(method lsproto.Method, params any) (any, error) { +func (s *Server) sendRequest(ctx context.Context, method lsproto.Method, params any) (any, error) { s.clientSeq++ id := lsproto.NewIDString(fmt.Sprintf("ts%d", s.clientSeq)) req := lsproto.NewRequestMessage(method, id, params) @@ -317,12 +332,21 @@ func (s *Server) sendRequest(method lsproto.Method, params any) (any, error) { s.outgoingQueue <- req.Message() - // TODO: timeout? - resp := <-responseChan - if resp.Error != nil { - return nil, fmt.Errorf("request failed: %s", resp.Error.String()) + select { + case <-ctx.Done(): + s.pendingServerRequestsMu.Lock() + defer s.pendingServerRequestsMu.Unlock() + if respChan, ok := s.pendingServerRequests[*id]; ok { + close(respChan) + delete(s.pendingServerRequests, *id) + } + return nil, ctx.Err() + case resp := <-responseChan: + if resp.Error != nil { + return nil, fmt.Errorf("request failed: %s", resp.Error.String()) + } + return resp.Result, nil } - return resp.Result, nil } func (s *Server) sendResult(id *lsproto.ID, result any) { @@ -358,25 +382,25 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R s.sendError(req.ID, lsproto.ErrInvalidRequest) return nil case *lsproto.InitializedParams: - s.handleInitialized(ctx, req) + return s.handleInitialized(ctx, req) case *lsproto.DidOpenTextDocumentParams: - s.handleDidOpen(ctx, req) + return s.handleDidOpen(ctx, req) case *lsproto.DidChangeTextDocumentParams: - s.handleDidChange(ctx, req) + return s.handleDidChange(ctx, req) case *lsproto.DidSaveTextDocumentParams: - s.handleDidSave(ctx, req) + return s.handleDidSave(ctx, req) case *lsproto.DidCloseTextDocumentParams: - s.handleDidClose(ctx, req) + return s.handleDidClose(ctx, req) case *lsproto.DidChangeWatchedFilesParams: - s.handleDidChangeWatchedFiles(ctx, req) + return s.handleDidChangeWatchedFiles(ctx, req) case *lsproto.DocumentDiagnosticParams: - s.handleDocumentDiagnostic(ctx, req) + return s.handleDocumentDiagnostic(ctx, req) case *lsproto.HoverParams: - s.handleHover(ctx, req) + return s.handleHover(ctx, req) case *lsproto.DefinitionParams: - s.handleDefinition(ctx, req) + return s.handleDefinition(ctx, req) case *lsproto.CompletionParams: - s.handleCompletion(ctx, req) + return s.handleCompletion(ctx, req) default: switch req.Method { case lsproto.MethodShutdown: @@ -384,8 +408,7 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R s.sendResult(req.ID, nil) return nil case lsproto.MethodExit: - s.fatalErrChan <- nil - return nil + return io.EOF default: s.Log("unknown method", req.Method) if req.ID != nil { @@ -394,7 +417,6 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R return nil } } - return nil } func (s *Server) handleInitialize(req *lsproto.RequestMessage) { @@ -484,7 +506,7 @@ func (s *Server) handleDidClose(ctx context.Context, req *lsproto.RequestMessage func (s *Server) handleDidChangeWatchedFiles(ctx context.Context, req *lsproto.RequestMessage) error { params := req.Params.(*lsproto.DidChangeWatchedFilesParams) - return s.projectService.OnWatchedFilesChanged(params.Changes) + return s.projectService.OnWatchedFilesChanged(ctx, params.Changes) } func (s *Server) handleDocumentDiagnostic(ctx context.Context, req *lsproto.RequestMessage) error { diff --git a/internal/project/host.go b/internal/project/host.go index b9e9635e26..18bad0b808 100644 --- a/internal/project/host.go +++ b/internal/project/host.go @@ -1,6 +1,8 @@ package project import ( + "context" + "github.com/microsoft/typescript-go/internal/lsp/lsproto" "github.com/microsoft/typescript-go/internal/vfs" ) @@ -8,9 +10,9 @@ import ( type WatcherHandle string type Client interface { - WatchFiles(watchers []*lsproto.FileSystemWatcher) (WatcherHandle, error) - UnwatchFiles(handle WatcherHandle) error - RefreshDiagnostics() error + WatchFiles(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (WatcherHandle, error) + UnwatchFiles(ctx context.Context, handle WatcherHandle) error + RefreshDiagnostics(ctx context.Context) error } type ServiceHost interface { diff --git a/internal/project/project.go b/internal/project/project.go index 917ee69a6c..bb8c92ced9 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -313,7 +313,7 @@ func (p *Project) getModuleResolutionWatchGlobs() (failedLookups map[tspath.Path return failedLookups, affectingLocaions } -func (p *Project) updateWatchers() { +func (p *Project) updateWatchers(ctx context.Context) { client := p.host.Client() if !p.host.IsWatchEnabled() || client == nil { return @@ -323,20 +323,20 @@ func (p *Project) updateWatchers() { failedLookupGlobs, affectingLocationGlobs := p.getModuleResolutionWatchGlobs() if rootFileGlobs != nil { - if updated, err := p.rootFilesWatch.update(rootFileGlobs); err != nil { + if updated, err := p.rootFilesWatch.update(ctx, rootFileGlobs); err != nil { p.log(fmt.Sprintf("Failed to update root file watch: %v", err)) } else if updated { p.log("Root file watches updated:\n" + formatFileList(rootFileGlobs, "\t", hr)) } } - if updated, err := p.failedLookupsWatch.update(failedLookupGlobs); err != nil { + if updated, err := p.failedLookupsWatch.update(ctx, failedLookupGlobs); err != nil { p.log(fmt.Sprintf("Failed to update failed lookup watch: %v", err)) } else if updated { p.log("Failed lookup watches updated:\n" + formatFileList(p.failedLookupsWatch.globs, "\t", hr)) } - if updated, err := p.affectingLocationsWatch.update(affectingLocationGlobs); err != nil { + if updated, err := p.affectingLocationsWatch.update(ctx, affectingLocationGlobs); err != nil { p.log(fmt.Sprintf("Failed to update affecting location watch: %v", err)) } else if updated { p.log("Affecting location watches updated:\n" + formatFileList(p.affectingLocationsWatch.globs, "\t", hr)) @@ -452,7 +452,9 @@ func (p *Project) updateGraph() bool { } } - p.updateWatchers() + // TODO: this is currently always synchronously called by some kind of updating request, + // but in Strada we throttle, so at least sometimes this should be considered top-level? + p.updateWatchers(context.TODO()) return true } diff --git a/internal/project/service.go b/internal/project/service.go index 730acf1b50..c2b75c081e 100644 --- a/internal/project/service.go +++ b/internal/project/service.go @@ -1,6 +1,7 @@ package project import ( + "context" "fmt" "strings" "sync" @@ -255,7 +256,7 @@ func (s *Service) SourceFileCount() int { return s.documentRegistry.size() } -func (s *Service) OnWatchedFilesChanged(changes []*lsproto.FileEvent) error { +func (s *Service) OnWatchedFilesChanged(ctx context.Context, changes []*lsproto.FileEvent) error { for _, change := range changes { fileName := ls.DocumentURIToFileName(change.Uri) path := s.toPath(fileName) @@ -286,7 +287,7 @@ func (s *Service) OnWatchedFilesChanged(changes []*lsproto.FileEvent) error { client := s.host.Client() if client != nil { - return client.RefreshDiagnostics() + return client.RefreshDiagnostics(ctx) } return nil diff --git a/internal/project/service_test.go b/internal/project/service_test.go index 476eee0ae2..cc0471dec4 100644 --- a/internal/project/service_test.go +++ b/internal/project/service_test.go @@ -316,7 +316,7 @@ func TestService(t *testing.T) { files := maps.Clone(defaultFiles) files["/home/projects/TS/p1/src/x.ts"] = `export const x = 2;` host.ReplaceFS(files) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeChanged, Uri: "file:///home/projects/TS/p1/src/x.ts", @@ -336,7 +336,7 @@ func TestService(t *testing.T) { files := maps.Clone(defaultFiles) files["/home/projects/TS/p1/src/x.ts"] = `export const x = 2;` host.ReplaceFS(files) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeChanged, Uri: "file:///home/projects/TS/p1/src/x.ts", @@ -375,7 +375,7 @@ func TestService(t *testing.T) { } }` host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeChanged, Uri: "file:///home/projects/TS/p1/tsconfig.json", @@ -407,7 +407,7 @@ func TestService(t *testing.T) { filesCopy := maps.Clone(files) delete(filesCopy, "/home/projects/TS/p1/src/x.ts") host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeDeleted, Uri: "file:///home/projects/TS/p1/src/x.ts", @@ -440,7 +440,7 @@ func TestService(t *testing.T) { filesCopy := maps.Clone(files) delete(filesCopy, "/home/projects/TS/p1/src/index.ts") host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeDeleted, Uri: "file:///home/projects/TS/p1/src/index.ts", @@ -496,7 +496,7 @@ func TestService(t *testing.T) { filesCopy := maps.Clone(files) filesCopy["/home/projects/TS/p1/src/y.ts"] = `export const y = 1;` host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeCreated, Uri: "file:///home/projects/TS/p1/src/y.ts", @@ -537,7 +537,7 @@ func TestService(t *testing.T) { filesCopy := maps.Clone(files) filesCopy["/home/projects/TS/p1/src/z.ts"] = `export const z = 1;` host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeCreated, Uri: "file:///home/projects/TS/p1/src/z.ts", @@ -573,7 +573,7 @@ func TestService(t *testing.T) { filesCopy := maps.Clone(files) filesCopy["/home/projects/TS/p1/src/a.ts"] = `const a = 1;` host.ReplaceFS(filesCopy) - assert.NilError(t, service.OnWatchedFilesChanged([]*lsproto.FileEvent{ + assert.NilError(t, service.OnWatchedFilesChanged(t.Context(), []*lsproto.FileEvent{ { Type: lsproto.FileChangeTypeCreated, Uri: "file:///home/projects/TS/p1/src/a.ts", diff --git a/internal/project/watch.go b/internal/project/watch.go index 18db3a8c91..adf32a4dbb 100644 --- a/internal/project/watch.go +++ b/internal/project/watch.go @@ -1,6 +1,7 @@ package project import ( + "context" "slices" "github.com/microsoft/typescript-go/internal/lsp/lsproto" @@ -29,7 +30,7 @@ func newWatchedFiles[T any](client Client, watchKind lsproto.WatchKind, getGlobs } } -func (w *watchedFiles[T]) update(newData T) (updated bool, err error) { +func (w *watchedFiles[T]) update(ctx context.Context, newData T) (updated bool, err error) { newGlobs := w.getGlobs(newData) w.data = newData if slices.Equal(w.globs, newGlobs) { @@ -38,7 +39,7 @@ func (w *watchedFiles[T]) update(newData T) (updated bool, err error) { w.globs = newGlobs if w.watcherID != "" { - if err = w.client.UnwatchFiles(w.watcherID); err != nil { + if err = w.client.UnwatchFiles(ctx, w.watcherID); err != nil { return false, err } } @@ -52,7 +53,7 @@ func (w *watchedFiles[T]) update(newData T) (updated bool, err error) { Kind: &w.watchKind, }) } - watcherID, err := w.client.WatchFiles(watchers) + watcherID, err := w.client.WatchFiles(ctx, watchers) if err != nil { return false, err } diff --git a/internal/testutil/projecttestutil/clientmock_generated.go b/internal/testutil/projecttestutil/clientmock_generated.go index e8bf1ab779..a19433ac8a 100644 --- a/internal/testutil/projecttestutil/clientmock_generated.go +++ b/internal/testutil/projecttestutil/clientmock_generated.go @@ -4,6 +4,7 @@ package projecttestutil import ( + "context" "sync" "github.com/microsoft/typescript-go/internal/lsp/lsproto" @@ -20,13 +21,13 @@ var _ project.Client = &ClientMock{} // // // make and configure a mocked project.Client // mockedClient := &ClientMock{ -// RefreshDiagnosticsFunc: func() error { +// RefreshDiagnosticsFunc: func(ctx context.Context) error { // panic("mock out the RefreshDiagnostics method") // }, -// UnwatchFilesFunc: func(handle project.WatcherHandle) error { +// UnwatchFilesFunc: func(ctx context.Context, handle project.WatcherHandle) error { // panic("mock out the UnwatchFiles method") // }, -// WatchFilesFunc: func(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { +// WatchFilesFunc: func(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { // panic("mock out the WatchFiles method") // }, // } @@ -37,26 +38,32 @@ var _ project.Client = &ClientMock{} // } type ClientMock struct { // RefreshDiagnosticsFunc mocks the RefreshDiagnostics method. - RefreshDiagnosticsFunc func() error + RefreshDiagnosticsFunc func(ctx context.Context) error // UnwatchFilesFunc mocks the UnwatchFiles method. - UnwatchFilesFunc func(handle project.WatcherHandle) error + UnwatchFilesFunc func(ctx context.Context, handle project.WatcherHandle) error // WatchFilesFunc mocks the WatchFiles method. - WatchFilesFunc func(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) + WatchFilesFunc func(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) // calls tracks calls to the methods. calls struct { // RefreshDiagnostics holds details about calls to the RefreshDiagnostics method. RefreshDiagnostics []struct { + // Ctx is the ctx argument value. + Ctx context.Context } // UnwatchFiles holds details about calls to the UnwatchFiles method. UnwatchFiles []struct { + // Ctx is the ctx argument value. + Ctx context.Context // Handle is the handle argument value. Handle project.WatcherHandle } // WatchFiles holds details about calls to the WatchFiles method. WatchFiles []struct { + // Ctx is the ctx argument value. + Ctx context.Context // Watchers is the watchers argument value. Watchers []*lsproto.FileSystemWatcher } @@ -67,9 +74,12 @@ type ClientMock struct { } // RefreshDiagnostics calls RefreshDiagnosticsFunc. -func (mock *ClientMock) RefreshDiagnostics() error { +func (mock *ClientMock) RefreshDiagnostics(ctx context.Context) error { callInfo := struct { - }{} + Ctx context.Context + }{ + Ctx: ctx, + } mock.lockRefreshDiagnostics.Lock() mock.calls.RefreshDiagnostics = append(mock.calls.RefreshDiagnostics, callInfo) mock.lockRefreshDiagnostics.Unlock() @@ -79,7 +89,7 @@ func (mock *ClientMock) RefreshDiagnostics() error { ) return errOut } - return mock.RefreshDiagnosticsFunc() + return mock.RefreshDiagnosticsFunc(ctx) } // RefreshDiagnosticsCalls gets all the calls that were made to RefreshDiagnostics. @@ -87,8 +97,10 @@ func (mock *ClientMock) RefreshDiagnostics() error { // // len(mockedClient.RefreshDiagnosticsCalls()) func (mock *ClientMock) RefreshDiagnosticsCalls() []struct { + Ctx context.Context } { var calls []struct { + Ctx context.Context } mock.lockRefreshDiagnostics.RLock() calls = mock.calls.RefreshDiagnostics @@ -97,10 +109,12 @@ func (mock *ClientMock) RefreshDiagnosticsCalls() []struct { } // UnwatchFiles calls UnwatchFilesFunc. -func (mock *ClientMock) UnwatchFiles(handle project.WatcherHandle) error { +func (mock *ClientMock) UnwatchFiles(ctx context.Context, handle project.WatcherHandle) error { callInfo := struct { + Ctx context.Context Handle project.WatcherHandle }{ + Ctx: ctx, Handle: handle, } mock.lockUnwatchFiles.Lock() @@ -112,7 +126,7 @@ func (mock *ClientMock) UnwatchFiles(handle project.WatcherHandle) error { ) return errOut } - return mock.UnwatchFilesFunc(handle) + return mock.UnwatchFilesFunc(ctx, handle) } // UnwatchFilesCalls gets all the calls that were made to UnwatchFiles. @@ -120,9 +134,11 @@ func (mock *ClientMock) UnwatchFiles(handle project.WatcherHandle) error { // // len(mockedClient.UnwatchFilesCalls()) func (mock *ClientMock) UnwatchFilesCalls() []struct { + Ctx context.Context Handle project.WatcherHandle } { var calls []struct { + Ctx context.Context Handle project.WatcherHandle } mock.lockUnwatchFiles.RLock() @@ -132,10 +148,12 @@ func (mock *ClientMock) UnwatchFilesCalls() []struct { } // WatchFiles calls WatchFilesFunc. -func (mock *ClientMock) WatchFiles(watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { +func (mock *ClientMock) WatchFiles(ctx context.Context, watchers []*lsproto.FileSystemWatcher) (project.WatcherHandle, error) { callInfo := struct { + Ctx context.Context Watchers []*lsproto.FileSystemWatcher }{ + Ctx: ctx, Watchers: watchers, } mock.lockWatchFiles.Lock() @@ -148,7 +166,7 @@ func (mock *ClientMock) WatchFiles(watchers []*lsproto.FileSystemWatcher) (proje ) return watcherHandleOut, errOut } - return mock.WatchFilesFunc(watchers) + return mock.WatchFilesFunc(ctx, watchers) } // WatchFilesCalls gets all the calls that were made to WatchFiles. @@ -156,9 +174,11 @@ func (mock *ClientMock) WatchFiles(watchers []*lsproto.FileSystemWatcher) (proje // // len(mockedClient.WatchFilesCalls()) func (mock *ClientMock) WatchFilesCalls() []struct { + Ctx context.Context Watchers []*lsproto.FileSystemWatcher } { var calls []struct { + Ctx context.Context Watchers []*lsproto.FileSystemWatcher } mock.lockWatchFiles.RLock() From 92f3ea2d923f97f9f285ded3dca9eba40ad8114e Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Thu, 15 May 2025 16:26:09 -0700 Subject: [PATCH 06/12] Use checkerPool.Files --- internal/compiler/program.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/compiler/program.go b/internal/compiler/program.go index 18540f24be..cc57a3b13d 100644 --- a/internal/compiler/program.go +++ b/internal/compiler/program.go @@ -218,10 +218,10 @@ func (p *Program) CheckSourceFiles(ctx context.Context) { wg := core.NewWorkGroup(p.singleThreaded()) checkers, done := p.checkerPool.GetAllCheckers(ctx) defer done() - for index, checker := range checkers { + for _, checker := range checkers { wg.Queue(func() { - for i := index; i < len(p.files); i += len(checkers) { - checker.CheckSourceFile(ctx, p.files[i]) + for file := range p.checkerPool.Files(checker) { + checker.CheckSourceFile(ctx, file) } }) } From 64928436299b4f24dbb7cb9ee30d896fbcaac9f0 Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Thu, 15 May 2025 19:35:23 -0700 Subject: [PATCH 07/12] More threading context, fixing bugs --- internal/api/api.go | 33 +++++---- internal/api/server.go | 5 +- internal/compiler/program.go | 7 +- internal/ls/api.go | 20 +++-- internal/ls/completions.go | 20 ++++- internal/ls/completions_test.go | 10 ++- internal/ls/definition.go | 9 ++- internal/ls/diagnostics.go | 9 ++- internal/ls/hover.go | 9 ++- internal/ls/languageservice.go | 37 +--------- internal/lsp/server.go | 9 ++- internal/project/checkerpool.go | 74 ++++++++++++------- internal/project/project.go | 34 ++++----- internal/project/service_test.go | 24 +++--- internal/testutil/lstestutil/lstestutil.go | 39 +++++++--- .../projecttestutil/projecttestutil.go | 6 ++ 16 files changed, 205 insertions(+), 140 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index 42191d6915..6f8e16d6bc 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "errors" "fmt" @@ -129,7 +130,7 @@ func (api *API) IsWatchEnabled() bool { return false } -func (api *API) HandleRequest(id int, method string, payload []byte) ([]byte, error) { +func (api *API) HandleRequest(ctx context.Context, method string, payload []byte) ([]byte, error) { params, err := unmarshalPayload(method, payload) if err != nil { return nil, err @@ -155,27 +156,27 @@ func (api *API) HandleRequest(id int, method string, payload []byte) ([]byte, er return encodeJSON(api.LoadProject(params.(*LoadProjectParams).ConfigFileName)) case MethodGetSymbolAtPosition: params := params.(*GetSymbolAtPositionParams) - return encodeJSON(api.GetSymbolAtPosition(params.Project, params.FileName, int(params.Position))) + return encodeJSON(api.GetSymbolAtPosition(ctx, params.Project, params.FileName, int(params.Position))) case MethodGetSymbolsAtPositions: params := params.(*GetSymbolsAtPositionsParams) return encodeJSON(core.TryMap(params.Positions, func(position uint32) (any, error) { - return api.GetSymbolAtPosition(params.Project, params.FileName, int(position)) + return api.GetSymbolAtPosition(ctx, params.Project, params.FileName, int(position)) })) case MethodGetSymbolAtLocation: params := params.(*GetSymbolAtLocationParams) - return encodeJSON(api.GetSymbolAtLocation(params.Project, params.Location)) + return encodeJSON(api.GetSymbolAtLocation(ctx, params.Project, params.Location)) case MethodGetSymbolsAtLocations: params := params.(*GetSymbolsAtLocationsParams) return encodeJSON(core.TryMap(params.Locations, func(location Handle[ast.Node]) (any, error) { - return api.GetSymbolAtLocation(params.Project, location) + return api.GetSymbolAtLocation(ctx, params.Project, location) })) case MethodGetTypeOfSymbol: params := params.(*GetTypeOfSymbolParams) - return encodeJSON(api.GetTypeOfSymbol(params.Project, params.Symbol)) + return encodeJSON(api.GetTypeOfSymbol(ctx, params.Project, params.Symbol)) case MethodGetTypesOfSymbols: params := params.(*GetTypesOfSymbolsParams) return encodeJSON(core.TryMap(params.Symbols, func(symbol Handle[ast.Symbol]) (any, error) { - return api.GetTypeOfSymbol(params.Project, symbol) + return api.GetTypeOfSymbol(ctx, params.Project, symbol) })) default: return nil, fmt.Errorf("unhandled API method %q", method) @@ -223,12 +224,14 @@ func (api *API) LoadProject(configFileName string) (*ProjectResponse, error) { return data, nil } -func (api *API) GetSymbolAtPosition(projectId Handle[project.Project], fileName string, position int) (*SymbolResponse, error) { +func (api *API) GetSymbolAtPosition(ctx context.Context, projectId Handle[project.Project], fileName string, position int) (*SymbolResponse, error) { project, ok := api.projects[projectId] if !ok { return nil, errors.New("project not found") } - symbol, err := project.LanguageService().GetSymbolAtPosition(fileName, position) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + symbol, err := languageService.GetSymbolAtPosition(ctx, fileName, position) if err != nil || symbol == nil { return nil, err } @@ -239,7 +242,7 @@ func (api *API) GetSymbolAtPosition(projectId Handle[project.Project], fileName return data, nil } -func (api *API) GetSymbolAtLocation(projectId Handle[project.Project], location Handle[ast.Node]) (*SymbolResponse, error) { +func (api *API) GetSymbolAtLocation(ctx context.Context, projectId Handle[project.Project], location Handle[ast.Node]) (*SymbolResponse, error) { project, ok := api.projects[projectId] if !ok { return nil, errors.New("project not found") @@ -262,7 +265,9 @@ func (api *API) GetSymbolAtLocation(projectId Handle[project.Project], location if node == nil { return nil, fmt.Errorf("node of kind %s not found at position %d in file %q", kind.String(), pos, sourceFile.FileName()) } - symbol := project.LanguageService().GetSymbolAtLocation(node) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + symbol := languageService.GetSymbolAtLocation(ctx, node) if symbol == nil { return nil, nil } @@ -273,7 +278,7 @@ func (api *API) GetSymbolAtLocation(projectId Handle[project.Project], location return data, nil } -func (api *API) GetTypeOfSymbol(projectId Handle[project.Project], symbolHandle Handle[ast.Symbol]) (*TypeResponse, error) { +func (api *API) GetTypeOfSymbol(ctx context.Context, projectId Handle[project.Project], symbolHandle Handle[ast.Symbol]) (*TypeResponse, error) { project, ok := api.projects[projectId] if !ok { return nil, errors.New("project not found") @@ -284,7 +289,9 @@ func (api *API) GetTypeOfSymbol(projectId Handle[project.Project], symbolHandle if !ok { return nil, fmt.Errorf("symbol %q not found", symbolHandle) } - t := project.LanguageService().GetTypeOfSymbol(symbol) + languageService, done := project.GetLanguageServiceForRequest(ctx) + defer done() + t := languageService.GetTypeOfSymbol(ctx, symbol) if t == nil { return nil, nil } diff --git a/internal/api/server.go b/internal/api/server.go index 2b60dd855a..061e5d7bae 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -2,13 +2,16 @@ package api import ( "bufio" + "context" "encoding/binary" "encoding/json" "fmt" "io" + "strconv" "sync" "github.com/microsoft/typescript-go/internal/bundled" + "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/project" "github.com/microsoft/typescript-go/internal/vfs" "github.com/microsoft/typescript-go/internal/vfs/osvfs" @@ -254,7 +257,7 @@ func (s *Server) handleRequest(method string, payload []byte) ([]byte, error) { case "echo": return payload, nil default: - return s.api.HandleRequest(s.requestId, method, payload) + return s.api.HandleRequest(core.WithRequestID(context.Background(), strconv.Itoa(s.requestId)), method, payload) } } diff --git a/internal/compiler/program.go b/internal/compiler/program.go index cc57a3b13d..d05b3fffa8 100644 --- a/internal/compiler/program.go +++ b/internal/compiler/program.go @@ -27,7 +27,7 @@ type ProgramOptions struct { SingleThreaded core.Tristate ProjectReference []core.ProjectReference ConfigFileParsingDiagnostics []*ast.Diagnostic - CheckerPool CheckerPool + CreateCheckerPool func(*Program) CheckerPool } type Program struct { @@ -78,8 +78,9 @@ func NewProgram(options ProgramOptions) *Program { if p.compilerOptions == nil { p.compilerOptions = &core.CompilerOptions{} } - p.checkerPool = options.CheckerPool - if p.checkerPool == nil { + if p.programOptions.CreateCheckerPool != nil { + p.checkerPool = p.programOptions.CreateCheckerPool(p) + } else { p.checkerPool = newCheckerPool(core.IfElse(p.singleThreaded(), 1, 4), p) } diff --git a/internal/ls/api.go b/internal/ls/api.go index cb0827459d..4393806e72 100644 --- a/internal/ls/api.go +++ b/internal/ls/api.go @@ -1,6 +1,7 @@ package ls import ( + "context" "errors" "fmt" @@ -14,8 +15,8 @@ var ( ErrNoTokenAtPosition = errors.New("no token found at position") ) -func (l *LanguageService) GetSymbolAtPosition(fileName string, position int) (*ast.Symbol, error) { - _, file := l.tryGetProgramAndFile(fileName) +func (l *LanguageService) GetSymbolAtPosition(ctx context.Context, fileName string, position int) (*ast.Symbol, error) { + program, file := l.tryGetProgramAndFile(fileName) if file == nil { return nil, fmt.Errorf("%w: %s", ErrNoSourceFile, fileName) } @@ -23,16 +24,21 @@ func (l *LanguageService) GetSymbolAtPosition(fileName string, position int) (*a if node == nil { return nil, fmt.Errorf("%w: %s:%d", ErrNoTokenAtPosition, fileName, position) } - checker := l.GetTypeChecker(file) + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() return checker.GetSymbolAtLocation(node), nil } -func (l *LanguageService) GetSymbolAtLocation(node *ast.Node) *ast.Symbol { - checker := l.GetTypeChecker(ast.GetSourceFileOfNode(node)) +func (l *LanguageService) GetSymbolAtLocation(ctx context.Context, node *ast.Node) *ast.Symbol { + program := l.GetProgram() + checker, done := program.GetTypeCheckerForFile(ctx, ast.GetSourceFileOfNode(node)) + defer done() return checker.GetSymbolAtLocation(node) } -func (l *LanguageService) GetTypeOfSymbol(symbol *ast.Symbol) *checker.Type { - checker := l.GetTypeChecker(nil /*file*/) +func (l *LanguageService) GetTypeOfSymbol(ctx context.Context, symbol *ast.Symbol) *checker.Type { + program := l.GetProgram() + checker, done := program.GetTypeChecker(ctx) + defer done() return checker.GetTypeOfSymbolAtLocation(symbol, nil) } diff --git a/internal/ls/completions.go b/internal/ls/completions.go index ec4abb82ca..257b4b2d87 100644 --- a/internal/ls/completions.go +++ b/internal/ls/completions.go @@ -1,6 +1,7 @@ package ls import ( + "context" "fmt" "maps" "slices" @@ -22,6 +23,7 @@ import ( ) func (l *LanguageService) ProvideCompletion( + ctx context.Context, documentURI lsproto.DocumentUri, position lsproto.Position, context *lsproto.CompletionContext, @@ -30,6 +32,7 @@ func (l *LanguageService) ProvideCompletion( ) (*lsproto.CompletionList, error) { program, file := l.getProgramAndFile(documentURI) return l.getCompletionsAtPosition( + ctx, program, file, int(l.converters.LineAndCharacterToPosition(file, position)), @@ -264,6 +267,7 @@ const ( ) func (l *LanguageService) getCompletionsAtPosition( + ctx context.Context, program *compiler.Program, file *ast.SourceFile, position int, @@ -294,7 +298,9 @@ func (l *LanguageService) getCompletionsAtPosition( // !!! label completions - data := getCompletionData(program, l.GetTypeChecker(file), file, position, preferences) + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() + data := getCompletionData(program, checker, file, position, preferences) if data == nil { return nil } @@ -302,6 +308,7 @@ func (l *LanguageService) getCompletionsAtPosition( switch data := data.(type) { case *completionDataData: response := l.completionInfoFromData( + ctx, file, program, compilerOptions, @@ -1457,6 +1464,7 @@ func getDefaultCommitCharacters(isNewIdentifierLocation bool) []string { } func (l *LanguageService) completionInfoFromData( + ctx context.Context, file *ast.SourceFile, program *compiler.Program, compilerOptions *core.CompilerOptions, @@ -1494,6 +1502,7 @@ func (l *LanguageService) completionInfoFromData( } uniqueNames, sortedEntries := l.getCompletionEntriesFromSymbols( + ctx, data, nil, /*replacementToken*/ position, @@ -1554,6 +1563,7 @@ func (l *LanguageService) completionInfoFromData( } func (l *LanguageService) getCompletionEntriesFromSymbols( + ctx context.Context, data *completionDataData, replacementToken *ast.Node, position int, @@ -1566,7 +1576,8 @@ func (l *LanguageService) getCompletionEntriesFromSymbols( ) (uniqueNames core.Set[string], sortedEntries []*lsproto.CompletionItem) { closestSymbolDeclaration := getClosestSymbolDeclaration(data.contextToken, data.location) useSemicolons := probablyUsesSemicolons(file) - typeChecker := l.GetTypeChecker(file) + typeChecker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() isMemberCompletion := isMemberCompletionKind(data.completionKind) optionalReplacementSpan := getOptionalReplacementSpan(data.location, file) // Tracks unique names. @@ -1608,6 +1619,7 @@ func (l *LanguageService) getCompletionEntriesFromSymbols( sortText = originalSortText } entry := l.createCompletionItem( + ctx, symbol, sortText, replacementToken, @@ -1675,6 +1687,7 @@ func createCompletionItemForLiteral( } func (l *LanguageService) createCompletionItem( + ctx context.Context, symbol *ast.Symbol, sortText sortText, replacementToken *ast.Node, @@ -1700,7 +1713,8 @@ func (l *LanguageService) createCompletionItem( source := getSourceFromOrigin(origin) var labelDetails *lsproto.CompletionItemLabelDetails - typeChecker := l.GetTypeChecker(file) + typeChecker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() insertQuestionDot := originIsNullableMember(origin) useBraces := originIsSymbolMember(origin) || needsConvertPropertyAccess if originIsThisType(origin) { diff --git a/internal/ls/completions_test.go b/internal/ls/completions_test.go index ad48c2ef3d..3b9b1edc0d 100644 --- a/internal/ls/completions_test.go +++ b/internal/ls/completions_test.go @@ -1,6 +1,7 @@ package ls_test import ( + "context" "slices" "testing" @@ -1551,7 +1552,9 @@ func runTest(t *testing.T, files map[string]string, expected map[string]*testCas parsedFiles[fileName] = content } } - languageService := createLanguageService(mainFileName, parsedFiles) + ctx := projecttestutil.WithRequestID(t.Context()) + languageService, done := createLanguageService(ctx, mainFileName, parsedFiles) + defer done() context := &lsproto.CompletionContext{ TriggerKind: lsproto.CompletionTriggerKindInvoked, } @@ -1576,6 +1579,7 @@ func runTest(t *testing.T, files map[string]string, expected map[string]*testCas t.Fatalf("No marker found for '%s'", markerName) } completionList, err := languageService.ProvideCompletion( + ctx, ls.FileNameToDocumentURI(mainFileName), marker.LSPosition, context, @@ -1611,11 +1615,11 @@ func assertIncludesItem(t *testing.T, actual *lsproto.CompletionList, expected * return false } -func createLanguageService(fileName string, files map[string]string) *ls.LanguageService { +func createLanguageService(ctx context.Context, fileName string, files map[string]string) (*ls.LanguageService, func()) { projectService, _ := projecttestutil.Setup(files) projectService.OpenFile(fileName, files[fileName], core.ScriptKindTS, "") project := projectService.Projects()[0] - return project.LanguageService() + return project.GetLanguageServiceForRequest(ctx) } func ptrTo[T any](v T) *T { diff --git a/internal/ls/definition.go b/internal/ls/definition.go index 9eeed3072d..6ccc155c2e 100644 --- a/internal/ls/definition.go +++ b/internal/ls/definition.go @@ -1,6 +1,8 @@ package ls import ( + "context" + "github.com/microsoft/typescript-go/internal/ast" "github.com/microsoft/typescript-go/internal/astnav" "github.com/microsoft/typescript-go/internal/core" @@ -8,14 +10,15 @@ import ( "github.com/microsoft/typescript-go/internal/scanner" ) -func (l *LanguageService) ProvideDefinition(documentURI lsproto.DocumentUri, position lsproto.Position) (*lsproto.Definition, error) { - file := l.getSourceFile(documentURI) +func (l *LanguageService) ProvideDefinition(ctx context.Context, documentURI lsproto.DocumentUri, position lsproto.Position) (*lsproto.Definition, error) { + program, file := l.getProgramAndFile(documentURI) node := astnav.GetTouchingPropertyName(file, int(l.converters.LineAndCharacterToPosition(file, position))) if node.Kind == ast.KindSourceFile { return nil, nil } - checker := l.GetTypeChecker(file) + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() if symbol := checker.GetSymbolAtLocation(node); symbol != nil { if symbol.Flags&ast.SymbolFlagsAlias != 0 { diff --git a/internal/ls/diagnostics.go b/internal/ls/diagnostics.go index b16a92566c..5bf4a1d1fd 100644 --- a/internal/ls/diagnostics.go +++ b/internal/ls/diagnostics.go @@ -8,9 +8,9 @@ import ( "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -func (l *LanguageService) GetDocumentDiagnostics(documentURI lsproto.DocumentUri) (*lsproto.DocumentDiagnosticReport, error) { +func (l *LanguageService) GetDocumentDiagnostics(ctx context.Context, documentURI lsproto.DocumentUri) (*lsproto.DocumentDiagnosticReport, error) { program, file := l.getProgramAndFile(documentURI) - syntaxDiagnostics := program.GetSyntacticDiagnostics(context.Background(), file) + syntaxDiagnostics := program.GetSyntacticDiagnostics(ctx, file) var lspDiagnostics []*lsproto.Diagnostic if len(syntaxDiagnostics) != 0 { lspDiagnostics = make([]*lsproto.Diagnostic, len(syntaxDiagnostics)) @@ -18,8 +18,9 @@ func (l *LanguageService) GetDocumentDiagnostics(documentURI lsproto.DocumentUri lspDiagnostics[i] = toLSPDiagnostic(diag, l.converters) } } else { - checker := l.GetTypeChecker(file) - semanticDiagnostics := checker.GetDiagnostics(context.Background(), file) + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() + semanticDiagnostics := checker.GetDiagnostics(ctx, file) lspDiagnostics = make([]*lsproto.Diagnostic, len(semanticDiagnostics)) for i, diag := range semanticDiagnostics { lspDiagnostics[i] = toLSPDiagnostic(diag, l.converters) diff --git a/internal/ls/hover.go b/internal/ls/hover.go index fd20995ab1..395148ba47 100644 --- a/internal/ls/hover.go +++ b/internal/ls/hover.go @@ -1,6 +1,7 @@ package ls import ( + "context" "strings" "github.com/microsoft/typescript-go/internal/ast" @@ -8,14 +9,16 @@ import ( "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) -func (l *LanguageService) ProvideHover(documentURI lsproto.DocumentUri, position lsproto.Position) (*lsproto.Hover, error) { - file := l.getSourceFile(documentURI) +func (l *LanguageService) ProvideHover(ctx context.Context, documentURI lsproto.DocumentUri, position lsproto.Position) (*lsproto.Hover, error) { + program, file := l.getProgramAndFile(documentURI) node := astnav.GetTouchingPropertyName(file, int(l.converters.LineAndCharacterToPosition(file, position))) if node.Kind == ast.KindSourceFile { // Avoid giving quickInfo for the sourceFile as a whole. return nil, nil } - result := l.GetTypeChecker(file).GetQuickInfoAtLocation(node) + checker, done := program.GetTypeCheckerForFile(ctx, file) + defer done() + result := checker.GetQuickInfoAtLocation(node) if result != "" { return &lsproto.Hover{ Contents: lsproto.MarkupContentOrMarkedStringOrMarkedStrings{ diff --git a/internal/ls/languageservice.go b/internal/ls/languageservice.go index 900fce3a14..e441a2fee0 100644 --- a/internal/ls/languageservice.go +++ b/internal/ls/languageservice.go @@ -4,21 +4,18 @@ import ( "context" "github.com/microsoft/typescript-go/internal/ast" - "github.com/microsoft/typescript-go/internal/checker" "github.com/microsoft/typescript-go/internal/compiler" "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) type LanguageService struct { - ctx context.Context - host Host - converters *Converters - disposables []func() + ctx context.Context + host Host + converters *Converters } func NewLanguageService(ctx context.Context, host Host) *LanguageService { return &LanguageService{ - ctx: ctx, host: host, converters: NewConverters(host.GetPositionEncoding(), host.GetLineMap), } @@ -29,40 +26,12 @@ func (l *LanguageService) GetProgram() *compiler.Program { return l.host.GetProgram() } -func (l *LanguageService) GetTypeChecker(file *ast.SourceFile) *checker.Checker { - var checker *checker.Checker - var done func() - if file == nil { - checker, done = l.GetProgram().GetTypeChecker(l.ctx) - } else { - checker, done = l.GetProgram().GetTypeCheckerForFile(l.ctx, file) - } - l.disposables = append(l.disposables, done) - return checker -} - -func (l *LanguageService) Dispose() { - for _, dispose := range l.disposables { - dispose() - } - l.disposables = nil -} - func (l *LanguageService) tryGetProgramAndFile(fileName string) (*compiler.Program, *ast.SourceFile) { program := l.GetProgram() file := program.GetSourceFile(fileName) return program, file } -func (l *LanguageService) getSourceFile(documentURI lsproto.DocumentUri) *ast.SourceFile { - fileName := DocumentURIToFileName(documentURI) - _, file := l.tryGetProgramAndFile(fileName) - if file == nil { - return nil - } - return file -} - func (l *LanguageService) getProgramAndFile(documentURI lsproto.DocumentUri) (*compiler.Program, *ast.SourceFile) { fileName := DocumentURIToFileName(documentURI) program, file := l.tryGetProgramAndFile(fileName) diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 136b3fb8ca..13c644240e 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -262,11 +262,13 @@ func (s *Server) read() (*lsproto.Message, error) { func (s *Server) dispatchLoop(ctx context.Context) error { ctx, lspExit := context.WithCancel(ctx) + defer lspExit() for { select { case <-ctx.Done(): return ctx.Err() case req := <-s.requestQueue: + ctx := ctx if req.ID != nil { var cancel context.CancelFunc ctx, cancel = context.WithCancel(core.WithRequestID(ctx, req.ID.String())) @@ -514,7 +516,7 @@ func (s *Server) handleDocumentDiagnostic(ctx context.Context, req *lsproto.Requ project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) languageService, done := project.GetLanguageServiceForRequest(ctx) defer done() - diagnostics, err := languageService.GetDocumentDiagnostics(params.TextDocument.Uri) + diagnostics, err := languageService.GetDocumentDiagnostics(ctx, params.TextDocument.Uri) if err != nil { return err } @@ -527,7 +529,7 @@ func (s *Server) handleHover(ctx context.Context, req *lsproto.RequestMessage) e project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) languageService, done := project.GetLanguageServiceForRequest(ctx) defer done() - hover, err := languageService.ProvideHover(params.TextDocument.Uri, params.Position) + hover, err := languageService.ProvideHover(ctx, params.TextDocument.Uri, params.Position) if err != nil { return err } @@ -540,7 +542,7 @@ func (s *Server) handleDefinition(ctx context.Context, req *lsproto.RequestMessa project := s.projectService.EnsureDefaultProjectForURI(params.TextDocument.Uri) languageService, done := project.GetLanguageServiceForRequest(ctx) defer done() - definition, err := languageService.ProvideDefinition(params.TextDocument.Uri, params.Position) + definition, err := languageService.ProvideDefinition(ctx, params.TextDocument.Uri, params.Position) if err != nil { return err } @@ -563,6 +565,7 @@ func (s *Server) handleCompletion(ctx context.Context, req *lsproto.RequestMessa }() // !!! get user preferences list, err := languageService.ProvideCompletion( + ctx, params.TextDocument.Uri, params.Position, params.Context, diff --git a/internal/project/checkerpool.go b/internal/project/checkerpool.go index e2bb313cdb..c6f9ce4041 100644 --- a/internal/project/checkerpool.go +++ b/internal/project/checkerpool.go @@ -11,7 +11,7 @@ import ( "github.com/microsoft/typescript-go/internal/core" ) -type CheckerPool struct { +type checkerPool struct { maxCheckers int program *compiler.Program @@ -24,10 +24,10 @@ type CheckerPool struct { requestAssociations map[string]int } -var _ compiler.CheckerPool = (*CheckerPool)(nil) +var _ compiler.CheckerPool = (*checkerPool)(nil) -func newCheckerPool(maxCheckers int, program *compiler.Program) *CheckerPool { - pool := &CheckerPool{ +func newCheckerPool(maxCheckers int, program *compiler.Program) *checkerPool { + pool := &checkerPool{ program: program, maxCheckers: maxCheckers, checkers: make([]*checker.Checker, maxCheckers), @@ -39,23 +39,14 @@ func newCheckerPool(maxCheckers int, program *compiler.Program) *CheckerPool { return pool } -func (p *CheckerPool) GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) { +func (p *checkerPool) GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) { p.mu.Lock() defer p.mu.Unlock() requestID := core.GetRequestID(ctx) if requestID != "" { - if index, ok := p.requestAssociations[requestID]; ok { - checker := p.checkers[index] - if checker != nil { - if inUse := p.inUse[checker]; !inUse { - p.inUse[checker] = true - return checker, p.createRelease(requestID, index, checker) - } - // Checker is in use, but by the same request - assume it's the - // same goroutine or is managing its own synchronization - return checker, noop - } + if checker, release := p.getRequestCheckerLocked(requestID); checker != nil { + return checker, release } } @@ -81,29 +72,33 @@ func (p *CheckerPool) GetCheckerForFile(ctx context.Context, file *ast.SourceFil return checker, p.createRelease(requestID, index, checker) } -func (p *CheckerPool) GetChecker(ctx context.Context) (*checker.Checker, func()) { +func (p *checkerPool) GetChecker(ctx context.Context) (*checker.Checker, func()) { p.mu.Lock() defer p.mu.Unlock() checker, index := p.getCheckerLocked(core.GetRequestID(ctx)) return checker, p.createRelease(core.GetRequestID(ctx), index, checker) } -func (p *CheckerPool) Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] { +func (p *checkerPool) Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] { panic("unimplemented") } -func (p *CheckerPool) GetAllCheckers(ctx context.Context) ([]*checker.Checker, func()) { +func (p *checkerPool) GetAllCheckers(ctx context.Context) ([]*checker.Checker, func()) { requestID := core.GetRequestID(ctx) if requestID == "" { panic("cannot call GetAllCheckers on a project.checkerPool without a request ID") } // A request can only access one checker + if c, release := p.getRequestCheckerLocked(requestID); c != nil { + return []*checker.Checker{c}, release + } + c, release := p.GetChecker(ctx) return []*checker.Checker{c}, release } -func (p *CheckerPool) getCheckerLocked(requestID string) (*checker.Checker, int) { +func (p *checkerPool) getCheckerLocked(requestID string) (*checker.Checker, int) { if checker, index := p.getImmediatelyAvailableChecker(); checker != nil { p.inUse[checker] = true if requestID != "" { @@ -129,7 +124,23 @@ func (p *CheckerPool) getCheckerLocked(requestID string) (*checker.Checker, int) return checker, index } -func (p *CheckerPool) getImmediatelyAvailableChecker() (*checker.Checker, int) { +func (p *checkerPool) getRequestCheckerLocked(requestID string) (*checker.Checker, func()) { + if index, ok := p.requestAssociations[requestID]; ok { + checker := p.checkers[index] + if checker != nil { + if inUse := p.inUse[checker]; !inUse { + p.inUse[checker] = true + return checker, p.createRelease(requestID, index, checker) + } + // Checker is in use, but by the same request - assume it's the + // same goroutine or is managing its own synchronization + return checker, noop + } + } + return nil, noop +} + +func (p *checkerPool) getImmediatelyAvailableChecker() (*checker.Checker, int) { for i, checker := range p.checkers { if checker == nil { continue @@ -142,7 +153,7 @@ func (p *CheckerPool) getImmediatelyAvailableChecker() (*checker.Checker, int) { return nil, -1 } -func (p *CheckerPool) waitForAvailableChecker() (*checker.Checker, int) { +func (p *checkerPool) waitForAvailableChecker() (*checker.Checker, int) { for { p.cond.Wait() checker, index := p.getImmediatelyAvailableChecker() @@ -152,7 +163,7 @@ func (p *CheckerPool) waitForAvailableChecker() (*checker.Checker, int) { } } -func (p *CheckerPool) createRelease(requestId string, index int, checker *checker.Checker) func() { +func (p *checkerPool) createRelease(requestId string, index int, checker *checker.Checker) func() { return func() { p.mu.Lock() defer p.mu.Unlock() @@ -169,7 +180,7 @@ func (p *CheckerPool) createRelease(requestId string, index int, checker *checke } } -func (p *CheckerPool) isFullLocked() bool { +func (p *checkerPool) isFullLocked() bool { for _, checker := range p.checkers { if checker == nil { return false @@ -178,7 +189,7 @@ func (p *CheckerPool) isFullLocked() bool { return true } -func (p *CheckerPool) createCheckerLocked() (*checker.Checker, int) { +func (p *checkerPool) createCheckerLocked() (*checker.Checker, int) { for i, existing := range p.checkers { if existing == nil { checker := checker.NewChecker(p.program) @@ -189,4 +200,17 @@ func (p *CheckerPool) createCheckerLocked() (*checker.Checker, int) { panic("called createCheckerLocked when pool is full") } +func (p *checkerPool) isRequestCheckerInUse(requestID string) bool { + p.mu.Lock() + defer p.mu.Unlock() + + if index, ok := p.requestAssociations[requestID]; ok { + checker := p.checkers[index] + if checker != nil { + return p.inUse[checker] + } + } + return false +} + func noop() {} diff --git a/internal/project/project.go b/internal/project/project.go index bb8c92ced9..982c28f4ec 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -112,6 +112,7 @@ type Project struct { compilerOptions *core.CompilerOptions parsedCommandLine *tsoptions.ParsedCommandLine program *compiler.Program + checkerPool *checkerPool // Watchers rootFilesWatch *watchedFiles[[]string] @@ -257,22 +258,20 @@ func (p *Project) GetLanguageServiceForRequest(ctx context.Context) (*ls.Languag if core.GetRequestID(ctx) == "" { panic("context must already have a request ID") } + program := p.GetProgram() + checkerPool := p.checkerPool snapshot := &snapshot{ project: p, positionEncoding: p.host.PositionEncoding(), - program: p.GetProgram(), + program: program, } languageService := ls.NewLanguageService(ctx, snapshot) - return languageService, languageService.Dispose -} - -func (p *Project) LanguageService() *ls.LanguageService { - snapshot := &snapshot{ - project: p, - positionEncoding: p.host.PositionEncoding(), - program: p.GetProgram(), + cleanup := func() { + if checkerPool.isRequestCheckerInUse(core.GetRequestID(ctx)) { + panic(fmt.Errorf("checker for request ID %s not returned to pool at end of request", core.GetRequestID(ctx))) + } } - return ls.NewLanguageService(nil /*context*/, snapshot) + return languageService, cleanup } func (p *Project) getRootFileWatchGlobs() []string { @@ -462,15 +461,16 @@ func (p *Project) updateProgram() { rootFileNames := p.GetRootFileNames() compilerOptions := p.GetCompilerOptions() - var program compiler.Program - program = *compiler.NewProgram(compiler.ProgramOptions{ - RootFiles: rootFileNames, - Host: p, - Options: compilerOptions, - CheckerPool: newCheckerPool(4, &program), + p.program = compiler.NewProgram(compiler.ProgramOptions{ + RootFiles: rootFileNames, + Host: p, + Options: compilerOptions, + CreateCheckerPool: func(program *compiler.Program) compiler.CheckerPool { + p.checkerPool = newCheckerPool(4, program) + return p.checkerPool + }, }) - p.program = &program p.program.BindSourceFiles() } diff --git a/internal/project/service_test.go b/internal/project/service_test.go index cc0471dec4..af403ad93b 100644 --- a/internal/project/service_test.go +++ b/internal/project/service_test.go @@ -365,7 +365,7 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/index.ts", files["/home/projects/TS/p1/src/index.ts"], core.ScriptKindTS, "") _, project := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/index.ts") program := project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) filesCopy := maps.Clone(files) filesCopy["/home/projects/TS/p1/tsconfig.json"] = `{ @@ -383,7 +383,7 @@ func TestService(t *testing.T) { })) program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) }) t.Run("delete explicitly included file", func(t *testing.T) { @@ -402,7 +402,7 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/index.ts", files["/home/projects/TS/p1/src/index.ts"], core.ScriptKindTS, "") _, project := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/index.ts") program := project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) filesCopy := maps.Clone(files) delete(filesCopy, "/home/projects/TS/p1/src/x.ts") @@ -415,7 +415,7 @@ func TestService(t *testing.T) { })) program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) assert.Check(t, program.GetSourceFile("/home/projects/TS/p1/src/x.ts") == nil) }) @@ -435,7 +435,7 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/x.ts", files["/home/projects/TS/p1/src/x.ts"], core.ScriptKindTS, "") _, project := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/x.ts") program := project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/x.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/x.ts"))), 0) filesCopy := maps.Clone(files) delete(filesCopy, "/home/projects/TS/p1/src/index.ts") @@ -448,7 +448,7 @@ func TestService(t *testing.T) { })) program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/x.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/x.ts"))), 1) }) t.Run("create explicitly included file", func(t *testing.T) { @@ -468,7 +468,7 @@ func TestService(t *testing.T) { program := project.GetProgram() // Initially should have an error because y.ts is missing - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) // Missing location should be watched assert.DeepEqual(t, host.ClientMock.WatchFilesCalls()[0].Watchers, []*lsproto.FileSystemWatcher{ @@ -505,7 +505,7 @@ func TestService(t *testing.T) { // Error should be resolved program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) assert.Check(t, program.GetSourceFile("/home/projects/TS/p1/src/y.ts") != nil) }) @@ -526,7 +526,7 @@ func TestService(t *testing.T) { program := project.GetProgram() // Initially should have an error because z.ts is missing - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) // Missing location should be watched assert.Check(t, slices.ContainsFunc(host.ClientMock.WatchFilesCalls()[1].Watchers, func(w *lsproto.FileSystemWatcher) bool { @@ -546,7 +546,7 @@ func TestService(t *testing.T) { // Error should be resolved and the new file should be included in the program program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) assert.Check(t, program.GetSourceFile("/home/projects/TS/p1/src/z.ts") != nil) }) @@ -567,7 +567,7 @@ func TestService(t *testing.T) { program := project.GetProgram() // Initially should have an error because declaration for 'a' is missing - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 1) // Add a new file through wildcard watch filesCopy := maps.Clone(files) @@ -582,7 +582,7 @@ func TestService(t *testing.T) { // Error should be resolved and the new file should be included in the program program = project.GetProgram() - assert.Equal(t, len(program.GetSemanticDiagnostics(t.Context(), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) + assert.Equal(t, len(program.GetSemanticDiagnostics(projecttestutil.WithRequestID(t.Context()), program.GetSourceFile("/home/projects/TS/p1/src/index.ts"))), 0) assert.Check(t, program.GetSourceFile("/home/projects/TS/p1/src/a.ts") != nil) }) }) diff --git a/internal/testutil/lstestutil/lstestutil.go b/internal/testutil/lstestutil/lstestutil.go index 7929ab4972..e4fe1d5f54 100644 --- a/internal/testutil/lstestutil/lstestutil.go +++ b/internal/testutil/lstestutil/lstestutil.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/microsoft/typescript-go/internal/core" + "github.com/microsoft/typescript-go/internal/ls" "github.com/microsoft/typescript-go/internal/lsp/lsproto" ) @@ -88,6 +89,18 @@ type TestFileInfo struct { // for FourSlashFile Content string } +// FileName implements ls.Script. +func (t *TestFileInfo) FileName() string { + return t.Filename +} + +// Text implements ls.Script. +func (t *TestFileInfo) Text() string { + return t.Content +} + +var _ ls.Script = (*TestFileInfo)(nil) + func parseFileContent(content string, filename string, markerMap map[string]*Marker, markers *[]*Marker) *TestFileInfo { // !!! chompLeadingSpace // !!! validate characters in markers @@ -100,7 +113,7 @@ func parseFileContent(content string, filename string, markerMap map[string]*Mar /// The total number of metacharacters removed from the file (so far) difference := 0 - /// Current position data + /// One-based current position data line := 1 column := 1 @@ -127,7 +140,7 @@ func parseFileContent(content string, filename string, markerMap map[string]*Mar position: (i - 1) - difference, sourcePosition: i - 1, sourceLine: line, - sourceColumn: column, + sourceColumn: column - 1, } } if previousCharacter == '*' && currentCharacter == '/' { @@ -159,10 +172,22 @@ func parseFileContent(content string, filename string, markerMap map[string]*Mar // Add the remaining text flush(-1) - return &TestFileInfo{ - Content: output, + // Set LS positions for markers + lineMap := ls.ComputeLineStarts(output) + converters := ls.NewConverters(lsproto.PositionEncodingKindUTF8, func(_ string) *ls.LineMap { + return lineMap + }) + + testFileInfo := &TestFileInfo{ Filename: filename, + Content: output, } + + for _, marker := range *markers { + marker.LSPosition = converters.PositionToLineAndCharacter(testFileInfo, core.TextPos(marker.Position)) + } + + return testFileInfo } func recordMarker( @@ -176,11 +201,7 @@ func recordMarker( marker := &Marker{ Filename: filename, Position: location.position, - LSPosition: lsproto.Position{ - Line: uint32(location.sourceLine - 1), - Character: uint32(location.sourceColumn - 1), - }, - Name: name, + Name: name, } // Verify markers for uniqueness if _, ok := markerMap[name]; ok { diff --git a/internal/testutil/projecttestutil/projecttestutil.go b/internal/testutil/projecttestutil/projecttestutil.go index ea70b913db..ed23751f49 100644 --- a/internal/testutil/projecttestutil/projecttestutil.go +++ b/internal/testutil/projecttestutil/projecttestutil.go @@ -1,12 +1,14 @@ package projecttestutil import ( + "context" "fmt" "io" "strings" "sync" "github.com/microsoft/typescript-go/internal/bundled" + "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/project" "github.com/microsoft/typescript-go/internal/vfs" "github.com/microsoft/typescript-go/internal/vfs/vfstest" @@ -80,3 +82,7 @@ func newProjectServiceHost(files map[string]string) *ProjectServiceHost { host.logger = project.NewLogger([]io.Writer{&host.output}, "", project.LogLevelVerbose) return host } + +func WithRequestID(ctx context.Context) context.Context { + return core.WithRequestID(ctx, "0") +} From 71d03be5c0babf465d927b257870ddac455ea3fb Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Fri, 16 May 2025 09:14:31 -0700 Subject: [PATCH 08/12] Add context.TODO with comment in sketchy place --- internal/compiler/emitHost.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/compiler/emitHost.go b/internal/compiler/emitHost.go index 3cd269a11d..7afff9d8ce 100644 --- a/internal/compiler/emitHost.go +++ b/internal/compiler/emitHost.go @@ -34,7 +34,10 @@ func (host *emitHost) WriteFile(fileName string, text string, writeByteOrderMark } func (host *emitHost) GetEmitResolver(file *ast.SourceFile, skipDiagnostics bool) printer.EmitResolver { - checker, done := host.program.GetTypeCheckerForFile(context.Background(), file) + // The context and done function don't matter in tsc, currently the only caller of this function. + // But if this ever gets used by LSP code, we'll need to thread the context properly and pass the + // done function to the caller to ensure resources are cleaned up at the end of the request. + checker, done := host.program.GetTypeCheckerForFile(context.TODO(), file) defer done() return checker.GetEmitResolver(file, skipDiagnostics) } From 450ad860a77746fbfb9f9747603c4b99b6a70d3b Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Fri, 16 May 2025 09:35:00 -0700 Subject: [PATCH 09/12] Format --- internal/lsp/server.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 13c644240e..1c369863c6 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -137,7 +137,6 @@ func (s *Server) WatchFiles(ctx context.Context, watchers []*lsproto.FileSystemW }, }, }) - if err != nil { return "", fmt.Errorf("failed to register file watcher: %w", err) } @@ -159,7 +158,6 @@ func (s *Server) UnwatchFiles(ctx context.Context, handle project.WatcherHandle) }, }, }) - if err != nil { return fmt.Errorf("failed to unregister file watcher: %w", err) } @@ -571,7 +569,6 @@ func (s *Server) handleCompletion(ctx context.Context, req *lsproto.RequestMessa params.Context, s.initializeParams.Capabilities.TextDocument.Completion, &ls.UserPreferences{}) - if err != nil { return err } From e89cecd1a7bd48bbbf4f1cfac1ae6dc87ec48fc4 Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Fri, 16 May 2025 09:45:28 -0700 Subject: [PATCH 10/12] Lint --- internal/lsp/server.go | 10 +++++----- internal/project/service.go | 3 ++- internal/project/service_test.go | 9 ++++++--- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 1c369863c6..2c9ff587a6 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -194,7 +194,7 @@ func (s *Server) readLoop(ctx context.Context) error { for { msg, err := s.read() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } if errors.Is(err, lsproto.ErrInvalidRequest) { @@ -266,10 +266,10 @@ func (s *Server) dispatchLoop(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() case req := <-s.requestQueue: - ctx := ctx + requestCtx := ctx if req.ID != nil { var cancel context.CancelFunc - ctx, cancel = context.WithCancel(core.WithRequestID(ctx, req.ID.String())) + requestCtx, cancel = context.WithCancel(core.WithRequestID(requestCtx, req.ID.String())) s.pendingClientRequestsMu.Lock() s.pendingClientRequests[*req.ID] = pendingClientRequest{ req: req, @@ -279,8 +279,8 @@ func (s *Server) dispatchLoop(ctx context.Context) error { } handle := func() { - if err := s.handleRequestOrNotification(ctx, req); err != nil { - if err == io.EOF { + if err := s.handleRequestOrNotification(requestCtx, req); err != nil { + if errors.Is(err, io.EOF) { lspExit() } else { s.sendError(req.ID, err) diff --git a/internal/project/service.go b/internal/project/service.go index c2b75c081e..2ec6068710 100644 --- a/internal/project/service.go +++ b/internal/project/service.go @@ -2,6 +2,7 @@ package project import ( "context" + "errors" "fmt" "strings" "sync" @@ -196,7 +197,7 @@ func (s *Service) ChangeFile(document lsproto.VersionedTextDocumentIdentifier, c NewText: wholeChange.Text, } } else { - return fmt.Errorf("invalid change type") + return errors.New("invalid change type") } } diff --git a/internal/project/service_test.go b/internal/project/service_test.go index af403ad93b..748b5189af 100644 --- a/internal/project/service_test.go +++ b/internal/project/service_test.go @@ -93,7 +93,7 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/x.ts", defaultFiles["/home/projects/TS/p1/src/x.ts"], core.ScriptKindTS, "") info, proj := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/x.ts") programBefore := proj.GetProgram() - service.ChangeFile( + err := service.ChangeFile( lsproto.VersionedTextDocumentIdentifier{ TextDocumentIdentifier: lsproto.TextDocumentIdentifier{ Uri: "file:///home/projects/TS/p1/src/x.ts", @@ -118,6 +118,7 @@ func TestService(t *testing.T) { }, }, ) + assert.NilError(t, err) assert.Equal(t, info.Text(), "export const x = 2;") assert.Equal(t, proj.CurrentProgram(), programBefore) assert.Equal(t, programBefore.GetSourceFile("/home/projects/TS/p1/src/x.ts").Text(), "export const x = 1;") @@ -131,7 +132,7 @@ func TestService(t *testing.T) { _, proj := service.EnsureDefaultProjectForFile("/home/projects/TS/p1/src/x.ts") programBefore := proj.GetProgram() indexFileBefore := programBefore.GetSourceFile("/home/projects/TS/p1/src/index.ts") - service.ChangeFile( + err := service.ChangeFile( lsproto.VersionedTextDocumentIdentifier{ TextDocumentIdentifier: lsproto.TextDocumentIdentifier{ Uri: "file:///home/projects/TS/p1/src/x.ts", @@ -156,6 +157,7 @@ func TestService(t *testing.T) { }, }, ) + assert.NilError(t, err) assert.Equal(t, proj.GetProgram().GetSourceFile("/home/projects/TS/p1/src/index.ts"), indexFileBefore) }) @@ -167,7 +169,7 @@ func TestService(t *testing.T) { service.OpenFile("/home/projects/TS/p1/src/index.ts", files["/home/projects/TS/p1/src/index.ts"], core.ScriptKindTS, "") assert.Check(t, service.GetScriptInfo("/home/projects/TS/p1/y.ts") == nil) - service.ChangeFile( + err := service.ChangeFile( lsproto.VersionedTextDocumentIdentifier{ TextDocumentIdentifier: lsproto.TextDocumentIdentifier{ Uri: "file:///home/projects/TS/p1/src/index.ts", @@ -192,6 +194,7 @@ func TestService(t *testing.T) { }, }, ) + assert.NilError(t, err) service.EnsureDefaultProjectForFile("/home/projects/TS/p1/y.ts") }) }) From 42c4fa43f2944d2f4fba3cdfbea956c68e6ab84d Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Fri, 16 May 2025 09:50:55 -0700 Subject: [PATCH 11/12] go mod tidy --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 067f388aa4..6be49ef853 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874 github.com/google/go-cmp v0.7.0 github.com/pkg/diff v0.0.0-20241224192749-4e6772a4315c + golang.org/x/sync v0.11.0 golang.org/x/sys v0.31.0 gotest.tools/v3 v3.5.2 ) @@ -14,7 +15,6 @@ require ( require ( github.com/matryer/moq v0.5.3 // indirect golang.org/x/mod v0.23.0 // indirect - golang.org/x/sync v0.11.0 // indirect golang.org/x/tools v0.30.0 // indirect ) From 48898b0fee0eb3ef62af882eade0cf58b5bb4832 Mon Sep 17 00:00:00 2001 From: Andrew Branch Date: Fri, 16 May 2025 13:07:50 -0700 Subject: [PATCH 12/12] Log checker stats, fix double program update --- internal/project/checkerpool.go | 19 +++++++++- internal/project/project.go | 65 ++++++++++++++++----------------- 2 files changed, 50 insertions(+), 34 deletions(-) diff --git a/internal/project/checkerpool.go b/internal/project/checkerpool.go index c6f9ce4041..2d1f4edd23 100644 --- a/internal/project/checkerpool.go +++ b/internal/project/checkerpool.go @@ -2,6 +2,7 @@ package project import ( "context" + "fmt" "iter" "sync" @@ -22,17 +23,19 @@ type checkerPool struct { inUse map[*checker.Checker]bool fileAssociations map[*ast.SourceFile]int requestAssociations map[string]int + log func(msg string) } var _ compiler.CheckerPool = (*checkerPool)(nil) -func newCheckerPool(maxCheckers int, program *compiler.Program) *checkerPool { +func newCheckerPool(maxCheckers int, program *compiler.Program, log func(msg string)) *checkerPool { pool := &checkerPool{ program: program, maxCheckers: maxCheckers, checkers: make([]*checker.Checker, maxCheckers), inUse: make(map[*checker.Checker]bool), requestAssociations: make(map[string]int), + log: log, } pool.cond = sync.NewCond(&pool.mu) @@ -154,6 +157,7 @@ func (p *checkerPool) getImmediatelyAvailableChecker() (*checker.Checker, int) { } func (p *checkerPool) waitForAvailableChecker() (*checker.Checker, int) { + p.log("checkerpool: Waiting for an available checker") for { p.cond.Wait() checker, index := p.getImmediatelyAvailableChecker() @@ -171,6 +175,7 @@ func (p *checkerPool) createRelease(requestId string, index int, checker *checke delete(p.requestAssociations, requestId) if checker.WasCanceled() { // Canceled checkers must be disposed + p.log(fmt.Sprintf("checkerpool: Checker for request %s was canceled, disposing it", requestId)) p.checkers[index] = nil delete(p.inUse, checker) } else { @@ -213,4 +218,16 @@ func (p *checkerPool) isRequestCheckerInUse(requestID string) bool { return false } +func (p *checkerPool) size() int { + p.mu.Lock() + defer p.mu.Unlock() + size := 0 + for _, checker := range p.checkers { + if checker != nil { + size++ + } + } + return size +} + func noop() {} diff --git a/internal/project/project.go b/internal/project/project.go index 982c28f4ec..d5ff54a7e2 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -84,13 +84,15 @@ type ProjectHost interface { Client() Client } +var _ compiler.CompilerHost = (*Project)(nil) + type Project struct { host ProjectHost - mu sync.Mutex name string kind Kind + dirtyStateMu sync.Mutex initialLoadPending bool dirty bool version int @@ -111,6 +113,7 @@ type Project struct { rootFileNames *collections.OrderedMap[tspath.Path, string] compilerOptions *core.CompilerOptions parsedCommandLine *tsoptions.ParsedCommandLine + programMu sync.Mutex program *compiler.Program checkerPool *checkerPool @@ -165,37 +168,30 @@ func NewProject(name string, kind Kind, currentDirectory string, host ProjectHos return project } -// FS implements LanguageServiceHost. +// FS implements compiler.CompilerHost. func (p *Project) FS() vfs.FS { return p.host.FS() } -// DefaultLibraryPath implements LanguageServiceHost. +// DefaultLibraryPath implements compiler.CompilerHost. func (p *Project) DefaultLibraryPath() string { return p.host.DefaultLibraryPath() } -// GetCompilerOptions implements LanguageServiceHost. -func (p *Project) GetCompilerOptions() *core.CompilerOptions { - return p.compilerOptions -} - -// GetCurrentDirectory implements LanguageServiceHost. +// GetCurrentDirectory implements compiler.CompilerHost. func (p *Project) GetCurrentDirectory() string { return p.currentDirectory } -// GetProjectVersion implements LanguageServiceHost. -func (p *Project) GetProjectVersion() int { - return p.version -} - -// GetRootFileNames implements LanguageServiceHost. func (p *Project) GetRootFileNames() []string { return slices.Collect(p.rootFileNames.Values()) } -// GetSourceFile implements LanguageServiceHost. +func (p *Project) GetCompilerOptions() *core.CompilerOptions { + return p.compilerOptions +} + +// GetSourceFile implements compiler.CompilerHost. func (p *Project) GetSourceFile(fileName string, path tspath.Path, languageVersion core.ScriptTarget) *ast.SourceFile { scriptKind := p.getScriptKind(fileName) if scriptInfo := p.getOrCreateScriptInfoAndAttachToProject(fileName, scriptKind); scriptInfo != nil { @@ -207,37 +203,32 @@ func (p *Project) GetSourceFile(fileName string, path tspath.Path, languageVersi oldSourceFile = p.program.GetSourceFileByPath(scriptInfo.path) oldCompilerOptions = p.program.GetCompilerOptions() } - return p.host.DocumentRegistry().AcquireDocument(scriptInfo, p.GetCompilerOptions(), oldSourceFile, oldCompilerOptions) + return p.host.DocumentRegistry().AcquireDocument(scriptInfo, p.compilerOptions, oldSourceFile, oldCompilerOptions) } return nil } -// GetProgram implements LanguageServiceHost. Updates the program if needed. +// Updates the program if needed. func (p *Project) GetProgram() *compiler.Program { p.updateIfDirty() return p.program } -// NewLine implements LanguageServiceHost. +// NewLine implements compiler.CompilerHost. func (p *Project) NewLine() string { return p.host.NewLine() } -// Trace implements LanguageServiceHost. +// Trace implements compiler.CompilerHost. func (p *Project) Trace(msg string) { p.log(msg) } -// GetDefaultLibraryPath implements ls.Host. +// GetDefaultLibraryPath implements compiler.CompilerHost. func (p *Project) GetDefaultLibraryPath() string { return p.host.DefaultLibraryPath() } -// GetPositionEncoding implements ls.Host. -func (p *Project) GetPositionEncoding() lsproto.PositionEncodingKind { - return p.host.PositionEncoding() -} - func (p *Project) Name() string { return p.name } @@ -385,8 +376,8 @@ func (p *Project) markFileAsDirty(path tspath.Path) { } func (p *Project) markAsDirty() { - p.mu.Lock() - defer p.mu.Unlock() + p.dirtyStateMu.Lock() + defer p.dirtyStateMu.Unlock() if !p.dirty { p.dirty = true p.version++ @@ -400,8 +391,8 @@ func (p *Project) updateIfDirty() bool { } func (p *Project) onFileAddedOrRemoved(isSymlink bool) { - p.mu.Lock() - defer p.mu.Unlock() + p.dirtyStateMu.Lock() + defer p.dirtyStateMu.Unlock() p.hasAddedOrRemovedFiles = true if isSymlink { p.hasAddedOrRemovedSymlinks = true @@ -413,7 +404,12 @@ func (p *Project) onFileAddedOrRemoved(isSymlink bool) { // opposite of the return value in Strada, which was frequently inverted, // as in `updateProjectIfDirty()`. func (p *Project) updateGraph() bool { - // !!! + p.programMu.Lock() + defer p.programMu.Unlock() + if !p.dirty { + return false + } + p.log("Starting updateGraph: Project: " + p.name) oldProgram := p.program hasAddedOrRemovedFiles := p.hasAddedOrRemovedFiles @@ -459,14 +455,17 @@ func (p *Project) updateGraph() bool { func (p *Project) updateProgram() { rootFileNames := p.GetRootFileNames() - compilerOptions := p.GetCompilerOptions() + compilerOptions := p.compilerOptions + if p.checkerPool != nil { + p.logf("Program %d used %d checker(s)", p.version, p.checkerPool.size()) + } p.program = compiler.NewProgram(compiler.ProgramOptions{ RootFiles: rootFileNames, Host: p, Options: compilerOptions, CreateCheckerPool: func(program *compiler.Program) compiler.CheckerPool { - p.checkerPool = newCheckerPool(4, program) + p.checkerPool = newCheckerPool(4, program, p.log) return p.checkerPool }, })