Infer type of generic class from return type of (optionally awaitable) callable passed to constructor #19143

Open
@DouweM

Description

I've confirmed that the following works in pyright and pyrefly, but not in mypy:

from dataclasses import dataclass

from typing_extensions import (
    Awaitable,
    Callable,
    Generic,
    TypeVar,
    assert_type,
)

T = TypeVar("T")


@dataclass
class Agent(Generic[T]):
    output_type: Callable[..., T] | Callable[..., Awaitable[T]]


async def coro() -> bool:
    return True


def func() -> int:
    return 1


# works
assert_type(Agent(func), Agent[int])

# mypy - error: Argument 1 to "Agent" has incompatible type "Callable[[], Coroutine[Any, Any, bool]]"; expected "Callable[..., Never] | Callable[..., Awaitable[Never]]"  [arg-type]
coro_agent = Agent(coro)
# pyright, pyrefly - works
# mypy - error: Expression is of type "Agent[Any]", not "Agent[bool]"
assert_type(coro_agent, Agent[bool])

# works
assert_type(Agent[bool](coro), Agent[bool])

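I want T to be inferred as the ultimate return type of the awaitable when an async function is passed rather than a regular one, but I suppose it's ambiguous which side of the union is the best match: the async function's type, Callable[[], Coroutine[Any, Any, bool]], also matches Callable[..., T] with T = Coroutine[Any, Any, bool].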

It would be great to see this work in mypy, but I'm also open to suggestions to do this in a less ambiguous way!
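To make the intended meaning of "ultimate return type" concrete, here is a minimal runtime sketch of how a caller might unwrap the result of such an output_type. call_output is a hypothetical helper, not part of any library; the point is that a checker should infer T as int for func and bool for coro, which is what mypy fails to do for the async case.

```python
from __future__ import annotations

import asyncio
import inspect
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar, cast

T = TypeVar("T")


async def call_output(
    output_type: Callable[..., T] | Callable[..., Awaitable[T]], *args: Any
) -> T:
    """Hypothetical helper: call output_type and await the result if it is awaitable."""
    result = output_type(*args)
    if inspect.isawaitable(result):
        # Async function: the final value is what the awaitable resolves to.
        return cast(T, await result)
    # Regular function: the return value is already the final value.
    return cast(T, result)


async def coro() -> bool:
    return True


def func() -> int:
    return 1


print(asyncio.run(call_output(func)), asyncio.run(call_output(coro)))  # prints: 1 True
```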

Activity

A5rocks commented on May 23, 2025

Collaborator

As you note, this is ambiguous. As a workaround, overloads work.
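A minimal sketch of the overload workaround, assuming the @dataclass is replaced by a hand-written class (a dataclass generates its own __init__, so it can't be overloaded directly). The design choice is the overload order: the Awaitable overload comes first, so an async function binds T to the awaited type, while a regular function fails that overload (its return type isn't awaitable) and falls through to the plain Callable[..., T] one. The inferred types in the comments are what a checker is expected to produce, not verified output.

```python
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from typing import Generic, TypeVar, overload

T = TypeVar("T")


class Agent(Generic[T]):
    output_type: Callable[..., T] | Callable[..., Awaitable[T]]

    # Awaitable overload first: async functions should match here with T = awaited type.
    @overload
    def __init__(self, output_type: Callable[..., Awaitable[T]]) -> None: ...
    # Fallback for regular functions: T is the plain return type.
    @overload
    def __init__(self, output_type: Callable[..., T]) -> None: ...
    def __init__(self, output_type: Callable[..., T] | Callable[..., Awaitable[T]]) -> None:
        self.output_type = output_type


async def coro() -> bool:
    return True


def func() -> int:
    return 1


int_agent = Agent(func)   # expected to be inferred as Agent[int]
bool_agent = Agent(coro)  # expected to be inferred as Agent[bool] via the first overload
print(int_agent.output_type(), asyncio.run(bool_agent.output_type()))  # prints: 1 True
```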

I'm not sure if there's any principled way around this. I have been thinking about a "strict coloring" mode which treats async functions as different types from non-async ones, but that would be stricter than you would like, and not everyone would enable it.

Maybe mypy should special-case T | Awaitable[T] specifically, since that's the only case where I've seen this.

DouweM commented on May 23, 2025

Author

@A5rocks I appreciate the quick response.

Naively, I'd imagine a general rule like "in case of multiple possible matches, pick the most specific one", which may be what pyright is doing (note that I haven't looked at the implementation).

Special-casing T | Awaitable[T] would work for me, but I'm curious why we couldn't do the same for any T | Foo[T] when given a Foo[X]. I'd expect this to work, for example:

# ensure_sequence[T] declares T itself (PEP 695, Python 3.12+), so no module-level TypeVar is needed
from typing import Sequence, assert_type


def ensure_sequence[T](x: T | Sequence[T]) -> Sequence[T]:
    if isinstance(x, Sequence):
        return x
    return [x]


assert_type(ensure_sequence(1), Sequence[int])
assert_type(ensure_sequence([1, 2, 3]), Sequence[int])

mypy currently says this:

error: Expression is of type "Sequence[Never]", not "Sequence[int]"  [assert-type]
error: Argument 1 to "ensure_sequence" has incompatible type "list[int]"; expected "Sequence[Never]"  [arg-type]

Note that pyright isn't entirely happy with this either, so maybe it is special-casing T | Awaitable[T]. It complains about x on the line return x, although it does let both assert_type calls pass:

error: Return type, "Sequence[Unknown]* | Sequence[T@ensure_sequence]", is partially unknown (reportUnknownVariableType)

A5rocks commented on May 23, 2025

Collaborator

in case of multiple possible matches, pick the most specific one

I imagine this wouldn't do well if there are multiple possible matches with the same specificity, or even something like:

from typing import Protocol, TypeVar

T = TypeVar("T")

class A(Protocol[T]):
    a: T

def f(x: T | A[T]) -> T: ...

(You could imagine A as Awaitable, with a: T standing in for its __await__ method, if you like; strictly, Awaitable[T].__await__ returns Generator[Any, Any, T] rather than T.)

However, it is an improvement in some cases, so if we can isolate those, that sounds fine. But if we're adding special cases, I would rather be specific, e.g. only special-casing T | Awaitable[T].

Maybe a better method is tracking the number of levels above the typevar in each union member and choosing the highest one? Or maybe discarding conflicting constraints in order of the union? Neither sounds very performant, of course.
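The "count the levels above the typevar" idea can be illustrated with a toy solver. This is a minimal sketch over a hypothetical tuple-based type representation (match, typevar_depth, candidates, and solve_deepest are illustrative names), not mypy's constraint solver; it ignores subtyping, variance, and multi-parameter generics. It first lists every union member that matches, which shows the ambiguity, then keeps the binding from the member where T is nested deepest.

```python
from typing import Optional, Union

# Toy type model: "T" is the type variable, any other str is a concrete type,
# and a tuple (name, arg) is a one-parameter generic such as ("A", "int").
ToyType = Union[str, tuple]


def match(pattern: ToyType, actual: ToyType, solution: dict) -> bool:
    """Unify pattern against actual, recording the binding for "T" in solution."""
    if pattern == "T":
        if "T" in solution and solution["T"] != actual:
            return False  # conflicting constraint
        solution["T"] = actual
        return True
    if isinstance(pattern, tuple) and isinstance(actual, tuple):
        return pattern[0] == actual[0] and match(pattern[1], actual[1], solution)
    return pattern == actual


def typevar_depth(pattern: ToyType) -> Optional[int]:
    """Number of generic levels wrapped around "T" (None if T does not occur)."""
    if pattern == "T":
        return 0
    if isinstance(pattern, tuple):
        inner = typevar_depth(pattern[1])
        return None if inner is None else inner + 1
    return None


def candidates(union: list, actual: ToyType) -> list:
    """(depth, binding) for every union member that matches actual."""
    found = []
    for member in union:
        solution: dict = {}
        if match(member, actual, solution) and "T" in solution:
            found.append((typevar_depth(member), solution["T"]))
    return found


def solve_deepest(union: list, actual: ToyType) -> Optional[ToyType]:
    """Pick the binding from the matching member where T is nested deepest."""
    found = candidates(union, actual)
    if not found:
        return None
    # max() returns the first of several equal maxima, so ties fall back to union order.
    return max(found, key=lambda pair: pair[0])[1]


union = ["T", ("A", "T")]                  # models x: T | A[T]
print(candidates(union, ("A", "int")))     # prints: [(0, ('A', 'int')), (1, 'int')]
print(solve_deepest(union, ("A", "int")))  # prints: int
```

Ties are exactly the worry raised above: when two members match at the same depth, this sketch silently prefers whichever comes first in the union.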

DouweM commented on May 23, 2025

Author

@A5rocks Good point, a special case sounds reasonable then. Thanks for considering this!

A5rocks commented on May 25, 2025

Collaborator

And BTW, I saw you misinterpreted me in your comment on the PR: __init__ can be overloaded.

DouweM commented on May 26, 2025

Author

@A5rocks I don't think that'd work with the real OutputType, which has the Callable[..., T] | Callable[..., Awaitable[T]] union nested a few levels down:

from collections.abc import Awaitable, Callable, Sequence
from typing import TypeVar, Union

from typing_extensions import TypeAliasType

# ToolOutput is a marker class from the author's library; its definition isn't shown here.
T_co = TypeVar('T_co', covariant=True)
# output_type=Type or output_type=function or output_type=object.method
SimpleOutputType = TypeAliasType(
    'SimpleOutputType', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,)
)
# output_type=ToolOutput(<see above>) or <see above>
SimpleOutputTypeOrMarker = TypeAliasType(
    'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,)
)
# output_type=<see above> or [<see above>, ...]
OutputType = TypeAliasType(
    'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,)
)

That means output_type can be a list mixing types, regular functions, and async functions:

def func(x: str, y: int) -> str:
    return f'{x} {y}'


async def coro(x: int, y: int) -> int:
    return x * y

# Foo and Bar stand for arbitrary output classes
complex_output_agent = Agent(output_type=[Foo, Bar, func, coro])
assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int])

Note that using Sequence[...] with type[T] also runs into #19142.

To use overloads, I think I'd need to define some new marker class like OutputFunc with overloads for regular functions and async functions. That's definitely an option, but would make the API a bit less clean, so if mypy is planning to fix this issue and the Sequence[type[T]] one, I'd rather keep it like this (which already works with pyright).
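For comparison, a minimal sketch of what such a marker could look like, assuming OutputFunc is a plain generic wrapper (its shape here is hypothetical; only the name comes from the comment above). It uses the same overload ordering trick as an overloaded __init__: the Awaitable overload first, the plain callable second. The inferred types in the comments are the expected checker behavior, not verified output.

```python
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from typing import Any, Generic, TypeVar, overload

T_co = TypeVar("T_co", covariant=True)


class OutputFunc(Generic[T_co]):
    # Hypothetical marker: wraps one output function and records its final result type.
    @overload
    def __init__(self, func: Callable[..., Awaitable[T_co]]) -> None: ...
    @overload
    def __init__(self, func: Callable[..., T_co]) -> None: ...
    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func


def func(x: str, y: int) -> str:
    return f"{x} {y}"


async def coro(x: int, y: int) -> int:
    return x * y


sync_output = OutputFunc(func)   # expected: OutputFunc[str]
async_output = OutputFunc(coro)  # expected: OutputFunc[int]
print(sync_output.func("a", 1), asyncio.run(async_output.func(2, 3)))  # prints: a 1 6
```

The cost is the one mentioned above: users would write output_type=[Foo, Bar, OutputFunc(func), OutputFunc(coro)] instead of passing the functions directly.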

Is there another option I'm missing?


Labels: bug (mypy got something wrong), topic-inference (When to infer types or require explicit annotations)


Participants: @DouweM, @A5rocks


          Infer type of generic class from return type of (optionally awaitable) callable passed to constructor · Issue #19143 · python/mypy