diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 8e77c0a5b..724cbb527 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -231,7 +231,7 @@ describe("OAuth Authorization", () => { ok: false, status: 404, }); - + // Second call (root fallback) succeeds mockFetch.mockResolvedValueOnce({ ok: true, @@ -241,17 +241,17 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com/path/name"); expect(metadata).toEqual(validMetadata); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(2); - + // First call should be path-aware const [firstUrl, firstOptions] = calls[0]; expect(firstUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/path/name"); expect(firstOptions.headers).toEqual({ "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION }); - + // Second call should be root fallback const [secondUrl, secondOptions] = calls[1]; expect(secondUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); @@ -266,7 +266,7 @@ describe("OAuth Authorization", () => { ok: false, status: 404, }); - + // Second call (root fallback) also returns 404 mockFetch.mockResolvedValueOnce({ ok: false, @@ -275,7 +275,7 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com/path/name"); expect(metadata).toBeUndefined(); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(2); }); @@ -289,10 +289,10 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com/"); expect(metadata).toBeUndefined(); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(1); // Should not attempt fallback - + const [url] = calls[0]; expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); }); @@ -306,10 +306,10 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com"); expect(metadata).toBeUndefined(); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(1); // Should not attempt fallback - + const [url] = calls[0]; expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); }); @@ -317,13 +317,13 @@ describe("OAuth Authorization", () => { it("falls back when path-aware discovery encounters CORS error", async () => { // First call (path-aware) fails with TypeError (CORS) mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError("CORS error"))); - + // Retry path-aware without headers (simulating CORS retry) mockFetch.mockResolvedValueOnce({ ok: false, status: 404, }); - + // Second call (root fallback) succeeds mockFetch.mockResolvedValueOnce({ ok: true, @@ -333,10 +333,10 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com/deep/path"); expect(metadata).toEqual(validMetadata); - + const calls = mockFetch.mock.calls; expect(calls.length).toBe(3); - + // Final call should be root fallback const [lastUrl, lastOptions] = calls[2]; expect(lastUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); @@ -1463,5 +1463,211 @@ describe("OAuth Authorization", () => { expect(body.get("grant_type")).toBe("refresh_token"); expect(body.get("refresh_token")).toBe("refresh123"); }); + + describe("delegateAuthorization", () => { + const validMetadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + const validTokens = { + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }; + + // Setup shared mock function for all tests + beforeEach(() => { + // Reset mockFetch implementation + mockFetch.mockReset(); + + // Set up the mockFetch to respond to all necessary API calls + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: false, + status: 404 + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => validMetadata + }); + } else if (urlString.includes("/token")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => validTokens + }); + } + + return Promise.reject(new Error(`Unexpected fetch call: ${urlString}`)); + }); + }); + + it("should use delegateAuthorization when implemented and return AUTHORIZED", async () => { + const mockProvider: OAuthClientProvider = { + redirectUrl: "http://localhost:3000/callback", + clientMetadata: { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client" + }, + clientInformation: () => validClientInfo, + tokens: () => validTokens, + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: () => "test_verifier", + delegateAuthorization: jest.fn().mockResolvedValue("AUTHORIZED") + }; + + const result = await auth(mockProvider, { serverUrl: "https://auth.example.com" }); + + expect(result).toBe("AUTHORIZED"); + expect(mockProvider.delegateAuthorization).toHaveBeenCalledWith( + "https://auth.example.com", + { + metadata: expect.objectContaining(validMetadata), + resource: undefined + } + ); + expect(mockProvider.redirectToAuthorization).not.toHaveBeenCalled(); + }); + + it("should fall back to standard flow when delegateAuthorization returns undefined", async () => { + const mockProvider: OAuthClientProvider = { + redirectUrl: "http://localhost:3000/callback", + clientMetadata: { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client" + }, + clientInformation: () => validClientInfo, + tokens: () => validTokens, + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: () => "test_verifier", + delegateAuthorization: jest.fn().mockResolvedValue(undefined) + }; + + const result = await auth(mockProvider, { serverUrl: "https://auth.example.com" }); + + expect(result).toBe("AUTHORIZED"); + expect(mockProvider.delegateAuthorization).toHaveBeenCalled(); + expect(mockProvider.saveTokens).toHaveBeenCalled(); + }); + + it("should not call delegateAuthorization when processing authorizationCode", async () => { + const mockProvider: OAuthClientProvider = { + redirectUrl: "http://localhost:3000/callback", + clientMetadata: { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client" + }, + clientInformation: () => validClientInfo, + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: () => "test_verifier", + delegateAuthorization: jest.fn() + }; + + await auth(mockProvider, { + serverUrl: "https://auth.example.com", + authorizationCode: "code123" + }); + + expect(mockProvider.delegateAuthorization).not.toHaveBeenCalled(); + expect(mockProvider.saveTokens).toHaveBeenCalled(); + }); + + it("should propagate errors from delegateAuthorization", async () => { + const mockProvider: OAuthClientProvider = { + redirectUrl: "http://localhost:3000/callback", + clientMetadata: { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client" + }, + clientInformation: () => validClientInfo, + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: () => "test_verifier", + delegateAuthorization: jest.fn().mockRejectedValue(new Error("Delegation failed")) + }; + + await expect(auth(mockProvider, { serverUrl: "https://auth.example.com" })) + .rejects.toThrow("Delegation failed"); + }); + + it("should pass both resource and metadata to delegateAuthorization when available", async () => { + // Mock resource metadata to be returned by the fetch + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://api.example.com/", + authorization_servers: ["https://auth.example.com"] + }) + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => validMetadata + }); + } + + return Promise.reject(new Error(`Unexpected fetch call: ${urlString}`)); + }); + + const mockProvider: OAuthClientProvider = { + redirectUrl: "http://localhost:3000/callback", + clientMetadata: { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client" + }, + clientInformation: () => validClientInfo, + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: () => "test_verifier", + delegateAuthorization: jest.fn().mockResolvedValue("AUTHORIZED") + }; + + const result = await auth(mockProvider, { serverUrl: "https://api.example.com" }); + + expect(result).toBe("AUTHORIZED"); + expect(mockProvider.delegateAuthorization).toHaveBeenCalledWith( + "https://auth.example.com", + { + resource: new URL("https://api.example.com/"), + metadata: expect.objectContaining(validMetadata) + } + ); + }); + }); }); }); diff --git a/src/client/auth.ts b/src/client/auth.ts index 376905743..700313d34 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -81,6 +81,32 @@ export interface OAuthClientProvider { * Implementations must verify the returned resource matches the MCP server. */ validateResourceURL?(serverUrl: string | URL, resource?: string): Promise; + + /** + * Optional method that allows the OAuth client to delegate authorization + * to an existing implementation, such as a platform or app-level identity provider. + * + * If this method returns "AUTHORIZED", the standard authorization flow will be bypassed. + * If it returns `undefined`, the SDK will proceed with its default OAuth implementation. + * + * When returning "AUTHORIZED", the implementation must ensure tokens have been saved + * through the provider's saveTokens method, or are accessible via the tokens() method. + * + * This method is useful when the host application already manages OAuth tokens or user sessions + * and does not need the SDK to handle the entire authorization flow directly. + * + * For example, in a mobile app, this could delegate to the native platform authentication, + * or in a browser application, it could use existing tokens from localStorage. + * + * Note: This method will NOT be called when processing an authorization code callback. + * + * @param serverUrl The URL of the authorization server. + * @param options The options for the method + * @param options.resource The protected resource (RFC 8707) to authorize (may be undefined if not available) + * @param options.metadata The OAuth metadata if available (may be undefined if discovery fails) + * @returns "AUTHORIZED" if delegation succeeded and tokens are already available; otherwise `undefined`. + */ + delegateAuthorization?(serverUrl: string | URL, options?: { resource?: URL, metadata?: OAuthMetadata}): "AUTHORIZED" | undefined | Promise<"AUTHORIZED" | undefined>; } export type AuthResult = "AUTHORIZED" | "REDIRECT"; @@ -124,6 +150,15 @@ export async function auth( const metadata = await discoverOAuthMetadata(authorizationServerUrl); + // Delegate the authorization if supported and if not already in the middle of the standard flow + if (provider.delegateAuthorization && authorizationCode === undefined) { + const options = resource || metadata ? { resource, metadata } : undefined; + const result = await provider.delegateAuthorization(authorizationServerUrl, options); + if (result === "AUTHORIZED") { + return "AUTHORIZED"; + } + } + // Handle client registration if needed let clientInformation = await Promise.resolve(provider.clientInformation()); if (!clientInformation) {