from __future__ import annotations

import copy
import json
import logging
import weakref
from collections.abc import Callable, Generator, Mapping, Sequence
from contextlib import contextmanager
from contextvars import ContextVar, Token
from dataclasses import dataclass
from logging import Logger
from typing import Any, Literal, cast, overload

import anyio
from mcp import LoggingLevel, ServerSession
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel.server import request_ctx
from mcp.shared.context import RequestContext
from mcp.types import (
    CreateMessageResult,
    CreateMessageResultWithTools,
    GetPromptResult,
    ModelPreferences,
    Root,
    SamplingMessage,
    SamplingMessageContentBlock,
    TextContent,
    ToolChoice,
    ToolResultContent,
    ToolUseContent,
)
from mcp.types import Prompt as SDKPrompt
from mcp.types import Resource as SDKResource
from mcp.types import Tool as SDKTool
from pydantic import ValidationError
from pydantic.networks import AnyUrl
from starlette.requests import Request
from typing_extensions import TypeVar

from fastmcp import settings
from fastmcp.server.elicitation import (
    AcceptedElicitation,
    CancelledElicitation,
    DeclinedElicitation,
    handle_elicit_accept,
    parse_elicit_response_type,
)
from fastmcp.server.sampling import SampleStep, SamplingResult, SamplingTool
from fastmcp.server.sampling.run import (
    _parse_model_preferences,
    call_sampling_handler,
    determine_handler_mode,
)
from fastmcp.server.sampling.run import (
    execute_tools as run_sampling_tools,
)
from fastmcp.server.server import FastMCP
from fastmcp.utilities.json_schema import compress_schema
from fastmcp.utilities.logging import _clamp_logger, get_logger
from fastmcp.utilities.types import get_cached_typeadapter

# Module logger, plus a dedicated child logger (suffix "to_client") for
# messages that are forwarded to MCP clients.
logger: Logger = get_logger(name=__name__)
to_client_logger: Logger = logger.getChild(suffix="to_client")

# Convert all levels of server -> client messages to debug level, so that
# traffic intended for the client does not surface at higher levels in the
# server's own logs.
# This clamp can be undone at runtime by calling `_unclamp_logger` or calling
# `_clamp_logger` with a different max level.
_clamp_logger(logger=to_client_logger, max_level="DEBUG")


# Generic type variables. Each declares a default (Any / str) that applies
# when the variable is left unparameterized.
T = TypeVar("T", default=Any)
ResultT = TypeVar("ResultT", default=str)

# Simplified tool choice type - just the mode string instead of the full MCP object
ToolChoiceOption = Literal["auto", "required", "none"]

# Holds the Context that is currently active; None when no Context has been
# set. Written by `set_context` and by `Context.__aenter__` / `__aexit__`.
_current_context: ContextVar[Context | None] = ContextVar("context", default=None)  # type: ignore[assignment]


# NOTE(review): no use of this lock is visible in this part of the file;
# presumably it serializes notification flushing - confirm against the caller.
_flush_lock = anyio.Lock()


@dataclass
class LogData:
    """Data object for passing log arguments to client-side handlers.

    This provides an interface to match the Python standard library logging,
    for compatibility with structured logging: a message string plus an
    optional mapping of extra values.
    """

    # The log message text.
    msg: str
    # Optional additional values to accompany the message; None when absent.
    extra: Mapping[str, Any] | None = None


# Maps each MCP logging level name to a Python `logging` level. Python has
# fewer levels than MCP, so several MCP levels share one Python level:
# "notice" is treated as INFO, and "alert"/"emergency" as CRITICAL.
_mcp_level_to_python_level: dict[str, int] = {
    "debug": logging.DEBUG,
    **dict.fromkeys(("info", "notice"), logging.INFO),
    "warning": logging.WARNING,
    "error": logging.ERROR,
    **dict.fromkeys(("critical", "alert", "emergency"), logging.CRITICAL),
}


@contextmanager
def set_context(context: Context) -> Generator[Context, None, None]:
    """Temporarily install ``context`` as the current context.

    Args:
        context: The Context to make current for the duration of the
            ``with`` block.

    Yields:
        The same ``context`` object that was passed in.
    """
    reset_token = _current_context.set(context)
    try:
        yield context
    finally:
        # Restore the previous value even if the with-body raised.
        _current_context.reset(reset_token)


@dataclass
class Context:
    """Context object providing access to MCP capabilities.

    This provides a cleaner interface to MCP's RequestContext functionality.
    It gets injected into tool and resource functions that request it via type hints.

    To use context in a tool function, add a parameter with the Context type annotation:

    ```python
    @server.tool
    async def my_tool(x: int, ctx: Context) -> str:
        # Log messages to the client
        await ctx.info(f"Processing {x}")
        await ctx.debug("Debug info")
        await ctx.warning("Warning message")
        await ctx.error("Error message")

        # Report progress
        await ctx.report_progress(50, 100, "Processing")

        # Access resources
        data = await ctx.read_resource("resource://data")

        # Get request info
        request_id = ctx.request_id
        client_id = ctx.client_id

        # Manage state across the request
        ctx.set_state("key", "value")
        value = ctx.get_state("key")

        return str(x)
    ```

    State Management:
    Context objects maintain a state dictionary that can be used to store and share
    data across middleware and tool calls within a request. When a new context
    is created (nested contexts), it inherits a copy of its parent's state, ensuring
    that modifications in child contexts don't affect parent contexts.

    The context parameter name can be anything as long as it's annotated with Context.
    The context is optional - tools that don't need it can omit the parameter.

    """

    def __init__(self, fastmcp: FastMCP):
        self._fastmcp: weakref.ref[FastMCP] = weakref.ref(fastmcp)
        self._tokens: list[Token] = []
        self._notification_queue: set[str] = set()  # Dedupe notifications
        self._state: dict[str, Any] = {}

    @property
    def fastmcp(self) -> FastMCP:
        """Get the FastMCP instance."""
        fastmcp = self._fastmcp()
        if fastmcp is None:
            raise RuntimeError("FastMCP instance is no longer available")
        return fastmcp

    async def __aenter__(self) -> Context:
        """Make this context current and publish its server for dependency injection.

        If another context is already current, its state dictionary is
        deep-copied into this one before this context takes over. The
        ContextVar tokens created here are stored on the instance for
        ``__aexit__`` to reset.

        Returns:
            This context.
        """
        enclosing = _current_context.get(None)
        if enclosing is not None:
            # Inherit state from the parent context (as a deep copy).
            self._state = copy.deepcopy(enclosing._state)

        # Always set this context and save the token
        self._tokens.append(_current_context.set(self))

        from fastmcp.server.dependencies import (
            _current_docket,
            _current_server,
            _current_worker,
        )

        # Set current server for dependency injection (use weakref to avoid
        # reference cycles)
        server = self.fastmcp
        self._server_token = _current_server.set(weakref.ref(server))

        # Set docket/worker from server instance for this request's context.
        # This ensures ContextVars work even in environments (like Lambda) where
        # lifespan ContextVars don't propagate to request handlers.
        docket = server._docket
        if docket is not None:
            self._docket_token = _current_docket.set(docket)

        worker = server._worker
        if worker is not None:
            self._worker_token = _current_worker.set(worker)

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit the context manager: flush notifications, then restore ContextVars.

        Pending notifications are flushed first. The server/docket/worker
        ContextVars set in ``__aenter__`` and the most recent current-context
        token are then reset. The resets run in a ``finally`` block so that a
        failing or cancelled flush cannot leave this context (or its server)
        installed as "current" after the ``async with`` block has ended. Any
        exception raised by the flush still propagates to the caller.

        Args:
            exc_type: Exception type from the ``async with`` body, if any (unused).
            exc_val: Exception instance from the ``async with`` body, if any (unused).
            exc_tb: Traceback from the ``async with`` body, if any (unused).
        """
        try:
            # Flush any remaining notifications before exiting
            await self._flush_notifications()
        finally:
            from fastmcp.server.dependencies import (
                _current_docket,
                _current_server,
                _current_worker,
            )

            # Reset server/docket/worker tokens in the reverse of the order
            # in which __aenter__ set them. Each attribute exists only if
            # __aenter__ actually set the corresponding ContextVar.
            if hasattr(self, "_worker_token"):
                _current_worker.reset(self._worker_token)
                delattr(self, "_worker_token")
            if hasattr(self, "_docket_token"):
                _current_docket.reset(self._docket_token)
                delattr(self, "_docket_token")
            if hasattr(self, "_server_token"):
                _current_server.reset(self._server_token)
                delattr(self, "_server_token")

            # Reset context token
            if self._tokens:
                token = self._tokens.pop()
                _current_context.reset(token)

    @property
    def request_context(self) -> RequestContext[ServerSession, Any, Request] | None:
        """The underlying MCP request context, or None if there is none yet.

        Returns None when the MCP session has not been established yet, and
        the full RequestContext once it is available.

        For HTTP request access in middleware, use `get_http_request()` from
        fastmcp.server.dependencies, which works whether or not the MCP
        session is available.

        Example in middleware:
        ```python
        async def on_request(self, context, call_next):
            ctx = context.fastmcp_context
            if ctx.request_context:
                # MCP session available - can access session_id, request_id, etc.
                session_id = ctx.session_id
            else:
                # MCP session not available yet - use HTTP helpers
                from fastmcp.server.dependencies import get_http_request
                request = get_http_request()
            return await call_next(context)
        ```
        """
        try:
            active = request_ctx.get()
        except LookupError:
            # The ContextVar has no value: no MCP request is in progress.
            active = None
        return active

    async def report_progress(
        self, progress: float, total: float | None = None, message: str | None = None
    ) -> None:
        """Report progress for the current operation.

        Args:
            progress: Current progress value e.g. 24
            total: Optional total value e.g. 100
        """

        progress_token = (
            self.request_context.meta.progressToken
            if self.request_context and self.request_context.meta
            else None
        )

        if progress_token is None:
            return

        await self.session.send_progress_notification(
            progress_token=progress_token,
            progress=progress,
            total=total,
            message=message,
            related_request_id=self.request_id,
        )

    async def list_resources(self) -> list[SDKResource]:
        """List all available resources from the server.

        Returns:
            List of Resource objects available on the server
        """
        return await self.fastmcp._list_resources_mcp()

    async def list_prompts(self) -> list[SDKPrompt]:
        """List all available prompts from the server.

        Returns:
            List of Prompt objects available on the server
        """
        return await self.fastmcp._list_prompts_mcp()

    async def get_prompt(
        self, name: str, arguments: dict[str, Any] | None = None
    ) -> GetPromptResult:
        """Get a prompt by name with optional arguments.

        Args:
            name: The name of the prompt to get
            arguments: Optional arguments to pass to the prompt

        Returns:
            The prompt result
        """
        return await self.fastmcp._get_prompt_mcp(name, arguments)

    async def read_resource(self, uri: str | AnyUrl) -> list[ReadResourceContents]:
        """Read a resource by URI.

        Args:
            uri: Resource URI to read

        Returns:
            The resource content as either text or bytes
        """
        # Context calls don't have task metadata, so always returns list
        return await self.fastmcp._read_resource_mcp(uri)  # type: ignore[return-value]

    async def log(
        self,
        message: str,
        level: LoggingLevel | None = None,
        logger_name: str | None = None,
        extra: Mapping[str, Any] | None = None,
    ) -> None:
        """Send a log message to the client.

        Messages sent to Clients are also logged to the `fastmcp.server.context.to_client` logger with a level of `DEBUG`.

        Args:
            message: Log message
            level: Optional log level. One of "debug", "info", "notice", "warning", "error", "critical",
                "alert", or "emergency". Default is "info".
            logger_name: Optional logger name
            extra: Optional mapping for additional arguments
        """
        payload = LogData(msg=message, extra=extra)
        # A missing level falls back to "info".
        effective_level = level or "info"

        await _log_to_server_and_client(
            data=payload,
            session=self.session,
            level=effective_level,
            logger_name=logger_name,
            related_request_id=self.request_id,
        )

    @property
    def client_id(self) -> str | None:
        """Get the client ID if available."""
        return (
            getattr(self.request_context.meta, "client_id", None)
            if self.request_context and self.request_context.meta
            else None
        )

    @property
    def request_id(self) -> str:
        """Get the unique ID for this request.

        Raises RuntimeError if MCP request context is not available.
        """
        if self.request_context is None:
            raise RuntimeError(
                "request_id is not available because the MCP session has not been established yet. "
                "Check `context.request_context` for None before accessing this attribute."
            )
        return str(self.request_context.request_id)

    @property
    def session_id(self) -> str:
        """Get the MCP session ID for ALL transports.

        Returns the session ID that can be used as a key for session-based
        data storage (e.g., Redis) to share data between tool calls within
        the same client session.

        Returns:
            The session ID for StreamableHTTP transports, or a generated ID
            for other transports.

        Raises:
            RuntimeError if MCP request context is not available.

        Example:
            ```python
            @server.tool
            def store_data(data: dict, ctx: Context) -> str:
                session_id = ctx.session_id
                redis_client.set(f"session:{session_id}:data", json.dumps(data))
                return f"Data stored for session {session_id}"
            ```
        """
        request_ctx = self.request_context
        if request_ctx is None:
            raise RuntimeError(
                "session_id is not available because the MCP session has not been established yet. "
                "Check `context.request_context` for None before accessing this attribute."
            )
        session = request_ctx.session

        # Try to get the session ID from the session attributes
        session_id = getattr(session, "_fastmcp_id", None)
        if session_id is not None:
            return session_id

        # Try to get the session ID from the http request headers
        request = request_ctx.request
        if request:
            session_id = request.headers.get("mcp-session-id")

        # Generate a session ID if it doesn't exist.
        if session_id is None:
            from uuid import uuid4

            session_id = str(uuid4())

        # Save the session id to the session attributes
        session._fastmcp_id = session_id  # type: ignore[attr-defined]
        return session_id

    @property
    def session(self) -> ServerSession:
        """Access to the underlying session for advanced usage.

        Raises RuntimeError if MCP request context is not available.
        """
        if self.request_context is None:
            raise RuntimeError(
                "session is not available because the MCP session has not been established yet. "
                "Check `context.request_context` for None before accessing this attribute."
            )
        return self.request_context.session

    # Convenience methods for common log levels
    async def debug(
        self,
        message: str,
        logger_name: str | None = None,
        extra: Mapping[str, Any] | None = None,
    ) -> None:
        """Send a `DEBUG`-level message to the connected MCP Client.

        Messages sent to Clients are also logged to the `fastmcp.server.context.to_client` logger with a level of `DEBUG`."""
        await self.log(
            level="debug",
            message=message,
            logger_name=logger_name,
            extra=extra,
        )

    async def info(
        self,
        message: str,
        logger_name: str | None = None,
        extra: Mapping[str, Any] | None = None,
    ) -> None:
        """Send a `INFO`-level message to the connected MCP Client.

        Messages sent to Clients are also logged to the `fastmcp.server.context.to_client` logger with a level of `DEBUG`."""
        await self.log(
            level="info",
            message=message,
            logger_name=logger_name,
            extra=extra,
        )

    async def warning(
        self,
        message: str,
        logger_name: str | None = None,
        extra: Mapping[str, Any] | None = None,
    ) -> None:
        """Send a `WARNING`-level message to the connected MCP Client.

        Messages sent to Clients are also logged to the `fastmcp.server.context.to_client` logger with a level of `DEBUG`."""
        await self.log(
            level="warning",
            message=message,
            logger_name=logger_name,
            extra=extra,
        )

    async def error(
        self,
        message: str,
        logger_name: str | None = None,
        extra: Mapping[str, Any] | None = None,
    ) -> None:
        """Send a `ERROR`-level message to the connected MCP Client.

        Messages sent to Clients are also logged to the `fastmcp.server.context.to_client` logger with a level of `DEBUG`."""
        await self.log(
            level="error",
            message=message,
            logger_name=logger_name,
            extra=extra,
        )

    async def list_roots(self) -> list[Root]:
        """List the roots available to the server, as indicated by the client."""
        result = await self.session.list_roots()
        return result.roots

    async def send_tool_list_changed(self) -> None:
        """Send a tool list changed notification to the client."""
        await self.session.send_tool_list_changed()

    async def send_resource_list_changed(self) -> None:
        """Send a resource list changed notification to the client."""
        await self.session.send_resource_list_changed()

    async def send_prompt_list_changed(self) -> None:
        """Send a prompt list changed notification to the client."""
        await self.session.send_prompt_list_changed()

    async def close_sse_stream(self) -> None:
        """Close the current response stream to trigger client reconnection.

        When using StreamableHTTP transport with an EventStore configured, this
        method gracefully closes the HTTP connection for the current request.
        The client will automatically reconnect (after `retry_interval` milliseconds)
        and resume receiving events from where it left off via the EventStore.

        This is useful for long-running operations to avoid load balancer timeouts.
        Instead of holding a connection open for minutes, you can periodically close
        and let the client reconnect.

        Example:
            ```python
            @mcp.tool
            async def long_running_task(ctx: Context) -> str:
                for i in range(100):
                    await ctx.report_progress(i, 100)

                    # Close connection every 30 iterations to avoid LB timeouts
                    if i % 30 == 0 and i > 0:
                        await ctx.close_sse_stream()

                    await do_work()
                return "Done"
            ```

        Note:
            This is a no-op (with a debug log) if not using StreamableHTTP
            transport with an EventStore configured.
        """
        if not self.request_context or not self.request_context.close_sse_stream:
            logger.debug(
                "close_sse_stream() called but not applicable "
                "(requires StreamableHTTP transport with event_store)"
            )
            return
        await self.request_context.close_sse_stream()

    async def sample_step(
        self,
        messages: str | Sequence[str | SamplingMessage],
        *,
        system_prompt: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        model_preferences: ModelPreferences | str | list[str] | None = None,
        tools: Sequence[SamplingTool | Callable[..., Any]] | None = None,
        tool_choice: ToolChoiceOption | str | None = None,
        execute_tools: bool = True,
        mask_error_details: bool | None = None,
    ) -> SampleStep:
        """
        Make a single LLM sampling call.

        Exactly one LLM call is made per invocation; any tools the LLM
        requested are then optionally executed. Use this for fine-grained
        control over the sampling loop.

        Args:
            messages: The message(s) to send. Can be a string, list of strings,
                or list of SamplingMessage objects.
            system_prompt: Optional system prompt for the LLM.
            temperature: Optional sampling temperature.
            max_tokens: Maximum tokens to generate. Defaults to 512.
            model_preferences: Optional model preferences.
            tools: Optional list of tools the LLM can use.
            tool_choice: Tool choice mode ("auto", "required", or "none").
            execute_tools: If True (default), execute tool calls and append results
                to history. If False, return immediately with tool_calls available
                in the step for manual execution.
            mask_error_details: If True, mask detailed error messages from tool
                execution. When None (default), uses the global settings value.
                Tools can raise ToolError to bypass masking.

        Returns:
            SampleStep containing:
            - .response: The raw LLM response
            - .history: Messages including input, assistant response, and tool results
            - .is_tool_use: True if the LLM requested tool execution
            - .tool_calls: List of tool calls (if any)
            - .text: The text content (if any)

        Raises:
            ValueError: If ``tool_choice`` is given but is not one of
                "auto", "required", or "none".

        Example:
            messages = "Research X"

            while True:
                step = await ctx.sample_step(messages, tools=[search])

                if not step.is_tool_use:
                    print(step.text)
                    break

                # Continue with tool results
                messages = step.history
        """
        # Normalize the inputs: messages become SamplingMessage objects and
        # tools become SamplingTools (plus their SDK form and a name lookup).
        history = _prepare_messages(messages)

        prepared_tools = _prepare_tools(tools)
        sdk_tools: list[SDKTool] | None = None
        tools_by_name: dict[str, SamplingTool] = {}
        if prepared_tools:
            sdk_tools = [tool._to_sdk_tool() for tool in prepared_tools]
            tools_by_name = {tool.name: tool for tool in prepared_tools}

        # Determine whether to use fallback handler or client
        use_fallback = determine_handler_mode(self, bool(prepared_tools))

        # Wrap the plain mode string in the MCP ToolChoice object.
        choice: ToolChoice | None = None
        if tool_choice is not None:
            if tool_choice not in ("auto", "required", "none"):
                raise ValueError(
                    f"Invalid tool_choice: {tool_choice!r}. "
                    "Must be 'auto', 'required', or 'none'."
                )
            choice = ToolChoice(
                mode=cast(Literal["auto", "required", "none"], tool_choice)
            )

        token_limit = 512 if max_tokens is None else max_tokens

        # Make the LLM call, via the fallback handler or the client session.
        if use_fallback:
            response = await call_sampling_handler(
                self,
                history,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=token_limit,
                model_preferences=model_preferences,
                sdk_tools=sdk_tools,
                tool_choice=choice,
            )
        else:
            response = await self.session.create_message(
                messages=history,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=token_limit,
                model_preferences=_parse_model_preferences(model_preferences),
                tools=sdk_tools,
                tool_choice=choice,
                related_request_id=self.request_id,
            )

        wants_tools = (
            isinstance(response, CreateMessageResultWithTools)
            and response.stopReason == "toolUse"
        )

        # Always include the assistant response in history
        history.append(SamplingMessage(role="assistant", content=response.content))

        # Nothing more to do unless the LLM asked for tools AND the caller
        # wants them executed here.
        if not (wants_tools and execute_tools):
            return SampleStep(response=response, history=history)

        # Execute tools and add results to history
        requested_calls = _extract_tool_calls(response)
        if requested_calls:
            mask = (
                settings.mask_error_details
                if mask_error_details is None
                else mask_error_details
            )
            tool_results = await run_sampling_tools(
                requested_calls, tools_by_name, mask_error_details=mask
            )

            if tool_results:
                history.append(
                    SamplingMessage(
                        role="user",
                        content=tool_results,  # type: ignore[arg-type]
                    )
                )

        return SampleStep(response=response, history=history)

    @overload
    async def sample(
        self,
        messages: str | Sequence[str | SamplingMessage],
        *,
        system_prompt: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        model_preferences: ModelPreferences | str | list[str] | None = None,
        tools: Sequence[SamplingTool | Callable[..., Any]] | None = None,
        result_type: type[ResultT],
        mask_error_details: bool | None = None,
    ) -> SamplingResult[ResultT]:
        """Typing overload: when ``result_type`` is given, returns SamplingResult[ResultT]."""

    @overload
    async def sample(
        self,
        messages: str | Sequence[str | SamplingMessage],
        *,
        system_prompt: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        model_preferences: ModelPreferences | str | list[str] | None = None,
        tools: Sequence[SamplingTool | Callable[..., Any]] | None = None,
        result_type: None = None,
        mask_error_details: bool | None = None,
    ) -> SamplingResult[str]:
        """Typing overload: when ``result_type`` is omitted or None, returns SamplingResult[str]."""

    async def sample(
        self,
        messages: str | Sequence[str | SamplingMessage],
        *,
        system_prompt: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        model_preferences: ModelPreferences | str | list[str] | None = None,
        tools: Sequence[SamplingTool | Callable[..., Any]] | None = None,
        result_type: type[ResultT] | None = None,
        mask_error_details: bool | None = None,
    ) -> SamplingResult[ResultT] | SamplingResult[str]:
        """
        Send a sampling request to the client and await the response.

        This method runs to completion automatically. When tools are provided,
        it executes a tool loop: if the LLM returns a tool use request, the tools
        are executed and the results are sent back to the LLM. This continues
        until the LLM provides a final text response.

        When result_type is specified, a synthetic `final_response` tool is
        created. The LLM calls this tool to provide the structured response,
        which is validated against the result_type and returned as `.result`.

        For fine-grained control over the sampling loop, use sample_step() instead.

        Args:
            messages: The message(s) to send. Can be a string, list of strings,
                or list of SamplingMessage objects.
            system_prompt: Optional system prompt for the LLM.
            temperature: Optional sampling temperature.
            max_tokens: Maximum tokens to generate. Defaults to 512.
            model_preferences: Optional model preferences.
            tools: Optional list of tools the LLM can use. Accepts plain
                functions or SamplingTools.
            result_type: Optional type for structured output. When specified,
                a synthetic `final_response` tool is created and the LLM's
                response is validated against this type.
            mask_error_details: If True, mask detailed error messages from tool
                execution. When None (default), uses the global settings value.
                Tools can raise ToolError to bypass masking.

        Returns:
            SamplingResult[T] containing:
            - .text: The text representation (raw text or JSON for structured)
            - .result: The typed result (str for text, parsed object for structured)
            - .history: All messages exchanged during sampling
        """
        # Safety limit to prevent infinite loops
        max_iterations = 100

        # Convert tools to SamplingTools
        sampling_tools = _prepare_tools(tools)

        # Handle structured output with result_type
        tool_choice: str | None = None
        if result_type is not None and result_type is not str:
            final_response_tool = _create_final_response_tool(result_type)
            sampling_tools = list(sampling_tools) if sampling_tools else []
            sampling_tools.append(final_response_tool)

            # Always require tool calls when result_type is set - the LLM must
            # eventually call final_response (text responses are not accepted)
            tool_choice = "required"

        # Convert messages for the loop
        current_messages: str | Sequence[str | SamplingMessage] = messages

        for _iteration in range(max_iterations):
            step = await self.sample_step(
                messages=current_messages,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                model_preferences=model_preferences,
                tools=sampling_tools,
                tool_choice=tool_choice,
                mask_error_details=mask_error_details,
            )

            # Check for final_response tool call for structured output
            if result_type is not None and result_type is not str and step.is_tool_use:
                for tool_call in step.tool_calls:
                    if tool_call.name == "final_response":
                        # Validate and return the structured result
                        type_adapter = get_cached_typeadapter(result_type)

                        # Unwrap if we wrapped primitives (non-object schemas)
                        input_data = tool_call.input
                        original_schema = compress_schema(
                            type_adapter.json_schema(), prune_titles=True
                        )
                        if (
                            original_schema.get("type") != "object"
                            and isinstance(input_data, dict)
                            and "value" in input_data
                        ):
                            input_data = input_data["value"]

                        try:
                            validated_result = type_adapter.validate_python(input_data)
                            text = json.dumps(
                                type_adapter.dump_python(validated_result, mode="json")
                            )
                            return SamplingResult(
                                text=text,
                                result=validated_result,
                                history=step.history,
                            )
                        except ValidationError as e:
                            # Validation failed - add error as tool result
                            step.history.append(
                                SamplingMessage(
                                    role="user",
                                    content=[
                                        ToolResultContent(
                                            type="tool_result",
                                            toolUseId=tool_call.id,
                                            content=[
                                                TextContent(
                                                    type="text",
                                                    text=(
                                                        f"Validation error: {e}. "
                                                        "Please try again with valid data."
                                                    ),
                                                )
                                            ],
                                            isError=True,
                                        )
                                    ],  # type: ignore[arg-type]
                                )
                            )

            # If not a tool use response, we're done
            if not step.is_tool_use:
                # For structured output, the LLM must use the final_response tool
                if result_type is not None and result_type is not str:
                    raise RuntimeError(
                        f"Expected structured output of type {result_type.__name__}, "
                        "but the LLM returned a text response instead of calling "
                        "the final_response tool."
                    )
                return SamplingResult(
                    text=step.text,
                    result=cast(ResultT, step.text if step.text else ""),
                    history=step.history,
                )

            # Continue with the updated history
            current_messages = step.history

            # After first iteration, reset tool_choice to auto
            tool_choice = None

        raise RuntimeError(f"Sampling exceeded maximum iterations ({max_iterations})")

    # Typing overloads for ``elicit``. They only affect static type checking;
    # the implementation that follows handles every case at runtime.

    # ``response_type`` is None (or omitted, matching the implementation's
    # default): the accepted elicitation will contain an empty dict.
    @overload
    async def elicit(
        self,
        message: str,
        response_type: None = None,
    ) -> (
        AcceptedElicitation[dict[str, Any]] | DeclinedElicitation | CancelledElicitation
    ): ...

    # ``response_type`` is a type: the accepted elicitation will contain the
    # response data as that type.
    @overload
    async def elicit(
        self,
        message: str,
        response_type: type[T],
    ) -> AcceptedElicitation[T] | DeclinedElicitation | CancelledElicitation: ...

    # ``response_type`` is a list of strings: the accepted elicitation will
    # contain the selected string response.
    @overload
    async def elicit(
        self,
        message: str,
        response_type: list[str],
    ) -> AcceptedElicitation[str] | DeclinedElicitation | CancelledElicitation: ...

    async def elicit(
        self,
        message: str,
        response_type: type[T] | list[str] | dict[str, dict[str, str]] | None = None,
    ) -> (
        AcceptedElicitation[T]
        | AcceptedElicitation[dict[str, Any]]
        | AcceptedElicitation[str]
        | AcceptedElicitation[list[str]]
        | DeclinedElicitation
        | CancelledElicitation
    ):
        """
        Ask the client for additional user input and wait for the answer.

        This can be called at any point to request more information from the
        user through the client. The client must support elicitation, or the
        request will error.

        The MCP protocol only supports simple object schemas with primitive
        types. A dataclass, TypedDict, or BaseModel satisfies that directly.
        For a primitive type, an object schema with a single "value" field is
        generated for the MCP interaction and automatically deconstructed
        into the primitive type upon response.

        When response_type is None, the generated schema is that of an empty
        object in order to comply with the MCP protocol requirements. Clients
        must send an empty object ("{}") in response.

        Args:
            message: A human-readable message explaining what information is
                needed.
            response_type: The type of the response, which should be a
                primitive type or dataclass or BaseModel. If it is a
                primitive type, an object schema with a single "value" field
                will be generated. The value is interpreted by
                ``parse_elicit_response_type``.

        Returns:
            The result of ``handle_elicit_accept`` when the client accepts,
            ``DeclinedElicitation`` when it declines, or
            ``CancelledElicitation`` when it cancels.

        Raises:
            ValueError: If the client reports an action other than "accept",
                "decline" or "cancel".
        """
        elicit_config = parse_elicit_response_type(response_type)

        reply = await self.session.elicit(
            message=message,
            requestedSchema=elicit_config.schema,
            related_request_id=self.request_id,
        )

        # Dispatch on the client's decision; anything unrecognised is an error.
        action = reply.action
        if action == "accept":
            return handle_elicit_accept(elicit_config, reply.content)
        if action == "decline":
            return DeclinedElicitation()
        if action == "cancel":
            return CancelledElicitation()
        raise ValueError(f"Unexpected elicitation action: {action}")

    def set_state(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the context state, replacing any existing entry."""
        state = self._state
        state[key] = value

    def get_state(self, key: str) -> Any:
        """Look up ``key`` in the context state; yields None when the key is absent."""
        stored = self._state.get(key)
        return stored

    def _queue_tool_list_changed(self) -> None:
        """Queue a tool list changed notification."""
        self._notification_queue.add("notifications/tools/list_changed")

    def _queue_resource_list_changed(self) -> None:
        """Queue a resource list changed notification."""
        self._notification_queue.add("notifications/resources/list_changed")

    def _queue_prompt_list_changed(self) -> None:
        """Queue a prompt list changed notification."""
        self._notification_queue.add("notifications/prompts/list_changed")

    async def _flush_notifications(self) -> None:
        """Send all queued list-changed notifications to the client.

        Runs while holding the module-level ``_flush_lock``. When the queue is
        empty nothing is sent. Otherwise the tools, resources and prompts
        list-changed notifications are sent (in that order) for whichever of
        them is queued, and the queue is cleared once all sends succeeded.

        Sending is best effort: if a send raises, the error is logged at debug
        level instead of propagating, and the queue is left untouched so the
        queued notifications are attempted again on the next flush.
        """
        async with _flush_lock:
            if not self._notification_queue:
                return

            try:
                if "notifications/tools/list_changed" in self._notification_queue:
                    await self.session.send_tool_list_changed()
                if "notifications/resources/list_changed" in self._notification_queue:
                    await self.session.send_resource_list_changed()
                if "notifications/prompts/list_changed" in self._notification_queue:
                    await self.session.send_prompt_list_changed()
                self._notification_queue.clear()
            except Exception:
                # Don't let notification failures break the request, but leave
                # a trace so a failing flush can be diagnosed.
                logger.debug(
                    "Failed to flush queued list-changed notifications",
                    exc_info=True,
                )


async def _log_to_server_and_client(
    data: LogData,
    session: ServerSession,
    level: LoggingLevel,
    logger_name: str | None = None,
    related_request_id: str | None = None,
) -> None:
    """Record a log entry on the server, then forward it to the client.

    The entry is first written to ``to_client_logger`` at the Python level
    mapped from ``level``, prefixed with a header naming the MCP level and,
    when given, the logger name. It is then sent to the client through
    ``session.send_log_message``.

    Args:
        data: The log payload; its ``msg`` and ``extra`` are used for the
            server-side record and the whole object is sent to the client.
        session: The session used to deliver the message to the client.
        level: The MCP logging level of the message.
        logger_name: Optional logger name reported alongside the message.
        related_request_id: Optional id of the request this message belongs to.
    """
    name_suffix = f" ({logger_name})" if logger_name else ""
    header = f"Sending {level.upper()} to client{name_suffix}"

    # Server-side copy of what is about to be sent.
    to_client_logger.log(
        level=_mcp_level_to_python_level[level],
        msg=f"{header}: {data.msg}",
        extra=data.extra,
    )

    await session.send_log_message(
        level=level,
        data=data,
        logger=logger_name,
        related_request_id=related_request_id,
    )


def _create_final_response_tool(result_type: type) -> SamplingTool:
    """Create a synthetic 'final_response' tool for structured output.

    This tool is used to capture structured responses from the LLM.
    The tool's parameter schema is derived from ``result_type``: the JSON
    schema produced by its type adapter is compressed (titles pruned) and,
    when it is not an object schema, wrapped in an object with a single
    required ``"value"`` property.

    Args:
        result_type: The type the final response must conform to.

    Returns:
        A ``SamplingTool`` named ``final_response`` whose function simply
        echoes back the keyword arguments it was called with.
    """
    type_adapter = get_cached_typeadapter(result_type)
    schema = compress_schema(type_adapter.json_schema(), prune_titles=True)

    # Tool parameters must be object-shaped. Wrap primitives in {"value": <schema>}
    if schema.get("type") != "object":
        # "#/$defs/..." references are resolved against the schema root, so
        # the definitions must stay at the top level of the wrapper instead
        # of being buried under properties.value.
        definitions = schema.get("$defs")
        value_schema = {k: v for k, v in schema.items() if k != "$defs"}
        schema = {
            "type": "object",
            "properties": {"value": value_schema},
            "required": ["value"],
        }
        if definitions:
            schema["$defs"] = definitions

    # The fn just returns the input as-is (validation is left to the code
    # that handles the tool call)
    def final_response(**kwargs: Any) -> dict[str, Any]:
        return kwargs

    return SamplingTool(
        name="final_response",
        description=(
            "Call this tool to provide your final response. "
            "Use this when you have completed the task and are ready to return the result."
        ),
        parameters=schema,
        fn=final_response,
    )


def _extract_text_from_content(
    content: SamplingMessageContentBlock | list[SamplingMessageContentBlock],
) -> str | None:
    """Extract text from content block(s).

    Returns the text if content is a TextContent or list containing TextContent,
    otherwise returns None.
    """
    if isinstance(content, list):
        for block in content:
            if isinstance(block, TextContent):
                return block.text
        return None
    elif isinstance(content, TextContent):
        return content.text
    return None


def _prepare_messages(
    messages: str | Sequence[str | SamplingMessage],
) -> list[SamplingMessage]:
    """Convert various message formats to a list of SamplingMessage objects."""
    if isinstance(messages, str):
        return [
            SamplingMessage(
                content=TextContent(text=messages, type="text"), role="user"
            )
        ]
    else:
        return [
            SamplingMessage(content=TextContent(text=m, type="text"), role="user")
            if isinstance(m, str)
            else m
            for m in messages
        ]


def _prepare_tools(
    tools: Sequence[SamplingTool | Callable[..., Any]] | None,
) -> list[SamplingTool] | None:
    """Convert tools to SamplingTool objects."""
    if tools is None:
        return None

    sampling_tools: list[SamplingTool] = []
    for t in tools:
        if isinstance(t, SamplingTool):
            sampling_tools.append(t)
        elif callable(t):
            sampling_tools.append(SamplingTool.from_function(t))
        else:
            raise TypeError(f"Expected SamplingTool or callable, got {type(t)}")

    return sampling_tools if sampling_tools else None


def _extract_tool_calls(
    response: CreateMessageResult | CreateMessageResultWithTools,
) -> list[ToolUseContent]:
    """Extract tool calls from a response."""
    content = response.content
    if isinstance(content, list):
        return [c for c in content if isinstance(c, ToolUseContent)]
    elif isinstance(content, ToolUseContent):
        return [content]
    return []
