Skip to content

Commit 2468de2

Browse files
feat(langchain): Support BaseCallbackManager
While implementing #4479, I noticed that our Langchain integration lacks support for the `local_callbacks` having type `BaseCallbackManager`, which according to the type hint is possible. This change adds support for this case.
1 parent 901b579 commit 2468de2

File tree

2 files changed

+172
-16
lines changed

2 files changed

+172
-16
lines changed

sentry_sdk/integrations/langchain.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from langchain_core.callbacks import (
2323
manager,
2424
BaseCallbackHandler,
25+
BaseCallbackManager,
2526
Callbacks,
2627
)
2728
from langchain_core.agents import AgentAction, AgentFinish
@@ -436,11 +437,42 @@ def new_configure(
436437
**kwargs,
437438
)
438439

439-
callbacks_list = local_callbacks or []
440+
# Lambda for lazy initialization of the SentryLangchainCallback
441+
sentry_handler_factory = lambda: SentryLangchainCallback(
442+
integration.max_spans,
443+
integration.include_prompts,
444+
integration.tiktoken_encoding_name,
445+
)
446+
447+
local_callbacks = local_callbacks or []
448+
449+
# Handle each possible type of local_callbacks. For each type, we
450+
# extract the list of callbacks to check for SentryLangchainCallback,
451+
# and define a function that would add the SentryLangchainCallback
452+
# to the existing callbacks list.
453+
if isinstance(local_callbacks, BaseCallbackManager):
454+
callbacks_list = local_callbacks.handlers
455+
456+
# For BaseCallbackManager, we want to copy the manager and add the
457+
# SentryLangchainCallback to the copy.
458+
def local_callbacks_with_sentry():
459+
new_manager = local_callbacks.copy()
460+
new_manager.handlers = [*new_manager.handlers, sentry_handler_factory()]
461+
return new_manager
462+
463+
elif isinstance(local_callbacks, BaseCallbackHandler):
464+
callbacks_list = [local_callbacks]
440465

441-
if isinstance(callbacks_list, BaseCallbackHandler):
442-
callbacks_list = [callbacks_list]
443-
elif not isinstance(callbacks_list, list):
466+
def local_callbacks_with_sentry():
467+
return [local_callbacks, sentry_handler_factory()]
468+
469+
elif isinstance(local_callbacks, list):
470+
callbacks_list = local_callbacks
471+
472+
def local_callbacks_with_sentry():
473+
return [*local_callbacks, sentry_handler_factory()]
474+
475+
else:
444476
logger.debug("Unknown callback type: %s", callbacks_list)
445477
# Just proceed with original function call
446478
return f(
@@ -452,20 +484,12 @@ def new_configure(
452484
)
453485

454486
if not any(isinstance(cb, SentryLangchainCallback) for cb in callbacks_list):
455-
# Avoid mutating the existing callbacks list
456-
callbacks_list = [
457-
*callbacks_list,
458-
SentryLangchainCallback(
459-
integration.max_spans,
460-
integration.include_prompts,
461-
integration.tiktoken_encoding_name,
462-
),
463-
]
487+
local_callbacks = local_callbacks_with_sentry()
464488

465489
return f(
466490
callback_manager_cls,
467491
inheritable_callbacks,
468-
callbacks_list,
492+
local_callbacks,
469493
*args,
470494
**kwargs,
471495
)

tests/integrations/langchain/test_langchain.py

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Optional, Any, Iterator
2+
from unittest import mock
23
from unittest.mock import Mock
34

45
import pytest
@@ -12,12 +13,15 @@
1213
# Langchain < 0.2
1314
from langchain_community.chat_models import ChatOpenAI
1415

15-
from langchain_core.callbacks import CallbackManagerForLLMRun
16+
from langchain_core.callbacks import BaseCallbackManager, CallbackManagerForLLMRun
1617
from langchain_core.messages import BaseMessage, AIMessageChunk
1718
from langchain_core.outputs import ChatGenerationChunk
1819

1920
from sentry_sdk import start_transaction
20-
from sentry_sdk.integrations.langchain import LangchainIntegration
21+
from sentry_sdk.integrations.langchain import (
22+
LangchainIntegration,
23+
SentryLangchainCallback,
24+
)
2125
from langchain.agents import tool, AgentExecutor, create_openai_tools_agent
2226
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
2327

@@ -342,3 +346,131 @@ def test_span_origin(sentry_init, capture_events):
342346
assert event["contexts"]["trace"]["origin"] == "manual"
343347
for span in event["spans"]:
344348
assert span["origin"] == "auto.ai.langchain"
349+
350+
351+
def test_langchain_callback_manager(sentry_init):
352+
sentry_init(
353+
integrations=[LangchainIntegration()],
354+
traces_sample_rate=1.0,
355+
)
356+
local_manager = BaseCallbackManager(handlers=[])
357+
358+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
359+
mock_configure = mock_manager_module._configure
360+
361+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
362+
LangchainIntegration.setup_once()
363+
364+
callback_manager_cls = Mock()
365+
366+
mock_manager_module._configure(
367+
callback_manager_cls, local_callbacks=local_manager
368+
)
369+
370+
assert mock_configure.call_count == 1
371+
372+
call_args = mock_configure.call_args
373+
assert call_args.args[0] is callback_manager_cls
374+
375+
passed_manager = call_args.args[2]
376+
assert passed_manager is not local_manager
377+
assert local_manager.handlers == []
378+
379+
[handler] = passed_manager.handlers
380+
assert isinstance(handler, SentryLangchainCallback)
381+
382+
383+
def test_langchain_callback_manager_with_sentry_callback(sentry_init):
384+
sentry_init(
385+
integrations=[LangchainIntegration()],
386+
traces_sample_rate=1.0,
387+
)
388+
sentry_callback = SentryLangchainCallback(0, False)
389+
local_manager = BaseCallbackManager(handlers=[sentry_callback])
390+
391+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
392+
mock_configure = mock_manager_module._configure
393+
394+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
395+
LangchainIntegration.setup_once()
396+
397+
callback_manager_cls = Mock()
398+
399+
mock_manager_module._configure(
400+
callback_manager_cls, local_callbacks=local_manager
401+
)
402+
403+
assert mock_configure.call_count == 1
404+
405+
call_args = mock_configure.call_args
406+
assert call_args.args[0] is callback_manager_cls
407+
408+
passed_manager = call_args.args[2]
409+
assert passed_manager is local_manager
410+
411+
[handler] = passed_manager.handlers
412+
assert handler is sentry_callback
413+
414+
415+
def test_langchain_callback_list(sentry_init):
416+
sentry_init(
417+
integrations=[LangchainIntegration()],
418+
traces_sample_rate=1.0,
419+
)
420+
local_callbacks = []
421+
422+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
423+
mock_configure = mock_manager_module._configure
424+
425+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
426+
LangchainIntegration.setup_once()
427+
428+
callback_manager_cls = Mock()
429+
430+
mock_manager_module._configure(
431+
callback_manager_cls, local_callbacks=local_callbacks
432+
)
433+
434+
assert mock_configure.call_count == 1
435+
436+
call_args = mock_configure.call_args
437+
assert call_args.args[0] is callback_manager_cls
438+
439+
passed_callbacks = call_args.args[2]
440+
assert passed_callbacks is not local_callbacks
441+
assert local_callbacks == []
442+
443+
[handler] = passed_callbacks
444+
assert isinstance(handler, SentryLangchainCallback)
445+
446+
447+
def test_langchain_callback_list_existing_callback(sentry_init):
448+
sentry_init(
449+
integrations=[LangchainIntegration()],
450+
traces_sample_rate=1.0,
451+
)
452+
sentry_callback = SentryLangchainCallback(0, False)
453+
local_callbacks = [sentry_callback]
454+
455+
with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module:
456+
mock_configure = mock_manager_module._configure
457+
458+
# Explicitly re-run setup_once, so that mock_manager_module._configure gets patched
459+
LangchainIntegration.setup_once()
460+
461+
callback_manager_cls = Mock()
462+
463+
mock_manager_module._configure(
464+
callback_manager_cls, local_callbacks=local_callbacks
465+
)
466+
467+
assert mock_configure.call_count == 1
468+
469+
call_args = mock_configure.call_args
470+
assert call_args.args[0] is callback_manager_cls
471+
472+
passed_callbacks = call_args.args[2]
473+
assert passed_callbacks is local_callbacks
474+
475+
[handler] = passed_callbacks
476+
assert handler is sentry_callback

0 commit comments

Comments
 (0)