"""Wallet service for point transactions."""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from typing import TYPE_CHECKING, Any, Dict, Optional

from .exceptions import (
    InsufficientBalanceError,
    InvalidParamsError,
    InvalidPointTypeError,
    WalletOperationError,
)

if TYPE_CHECKING:
    from django.db.models import Model

    from .repository import WalletRepositoryProtocol


@dataclass(frozen=True)
class TransactionRecord:
    """
    Immutable snapshot of a single wallet transaction.

    ``extra_data`` is excluded from the generated ``__hash__`` (dicts are unhashable, so
    hashing a frozen record would otherwise raise ``TypeError``) but still takes part in
    ``==`` comparisons. It defaults to an empty dict when omitted.
    """

    id: int | None  # primary key; None for a record that has not been persisted yet
    wtype: str  # wallet type / point type
    iid: int  # initiator id (user who performed transaction, -100 for system)
    uid: int  # user id of the wallet owner
    type: str  # 'c' for credit/add, 'd' for debit/deduct
    amount: Decimal  # transaction amount
    balance: Decimal  # balance after transaction
    trans_type: int | None  # transaction type code (integer constant)
    descr: str  # remarks/description
    cdate: str  # creation date/time, stored as a string
    # Additional caller-supplied fields; hash=False keeps the record hashable.
    extra_data: Dict[str, Any] = field(default_factory=dict, hash=False)


class WalletService:
    """
    Python port of the legacy WalletPoint PHP class.

    Persistence is delegated to a ``WalletRepositoryProtocol``. This service validates input,
    rounds amounts to the decimal places configured per point type, changes balances through
    the repository and records one ``TransactionRecord`` per balance change. It also offers
    read helpers (history, lookup by id, counting) over the repository's wallet model.
    """

    # Reserved fields that cannot be used in extra_data (they collide with TransactionRecord fields).
    RESERVED_FIELDS = {
        "id",
        "wtype",
        "iid",
        "uid",
        "type",
        "amount",
        "balance",
        "trans_type",
        "descr",
        "cdate",
    }

    def __init__(self, repository: "WalletRepositoryProtocol"):
        """
        Initialize the wallet service.

        Args:
            repository: Repository supplying point types (with their decimal places), balance
                updates, transaction persistence and the wallet transaction model.
        """
        self.repository = repository

    # ------------------------------------------------------------------ write operations

    def add_point(
        self,
        user_id: int,
        point_type: str,
        amount: Decimal | int | float | str,
        remarks: str,
        trans_type: int | None = None,
        params: Optional[Dict[str, Any]] = None,
        iid: int = -100,
    ) -> int:
        """
        Add points to a user wallet balance and record a credit ("c") transaction.

        Args:
            user_id: The user ID
            point_type: The type of wallet to add to (e.g., "credit_balance")
            amount: The amount to add; must be a positive, finite number. ``Decimal`` is
                preferred; ``int``, ``str`` and ``float`` are converted (floats via ``str``).
                It is rounded half-up to the point type's configured decimal places.
            remarks: Description/remarks for the transaction
            trans_type: Transaction type code (integer constant, optional)
            params: Extra parameters dict with optional 'data' key for additional fields
            iid: Initiator ID - ID of the user who performed this transaction (-100 for system)

        Returns:
            The transaction ID

        Raises:
            InvalidPointTypeError: If point type doesn't exist
            InvalidParamsError: If params or amount are invalid
            WalletOperationError: If the balance update or the transaction record fails
        """
        if params is None:
            params = {}

        self._validate_add_deduct_point(
            "add", user_id, point_type, amount, remarks, trans_type, False, params
        )
        amount_rounded, _ = self._round_amount(amount, point_type)

        # NOTE(review): the balance update and the record insert are two separate repository
        # calls. Unless the caller or repository wraps them in one DB transaction, a failing
        # create_transaction_record leaves the balance changed without a record — confirm.
        try:
            balance_after = self.repository.update_balance(
                user_id, point_type, amount_rounded, allow_negative=False
            )
        except Exception as e:
            raise WalletOperationError(f"Error adding user point: {e}") from e

        return self._save_transaction(
            user_id=user_id,
            point_type=point_type,
            entry_type="c",  # credit/add
            amount=amount_rounded,
            balance_after=balance_after,
            trans_type=trans_type,
            remarks=remarks,
            params=params,
            iid=iid,
        )

    def deduct_point(
        self,
        user_id: int,
        point_type: str,
        amount: Decimal | int | float | str,
        remarks: str,
        trans_type: int | None = None,
        allow_negative: bool = False,
        params: Optional[Dict[str, Any]] = None,
        iid: int = -100,
    ) -> int:
        """
        Deduct points from a user wallet balance and record a debit ("d") transaction.

        The sufficiency check and the update are delegated to
        ``repository.deduct_balance_atomic``. That method is expected to perform them in one
        conditional statement, so concurrent deductions cannot overdraw the wallet.

        Args:
            user_id: The user ID
            point_type: The type of wallet to deduct from
            amount: The amount to deduct; must be a positive, finite number (same conversion
                and rounding rules as ``add_point``)
            remarks: Description/remarks for the transaction
            trans_type: Transaction type code (integer constant, optional)
            allow_negative: Whether to allow negative balance (default: False)
            params: Extra parameters dict with optional 'data' key for additional fields
            iid: Initiator ID - ID of the user who performed this transaction (-100 for system)

        Returns:
            The transaction ID

        Raises:
            InsufficientBalanceError: If the repository reports the deduction did not happen
                (insufficient balance and allow_negative=False)
            InvalidPointTypeError: If point type doesn't exist
            InvalidParamsError: If params or amount are invalid
            WalletOperationError: If the deduction or the transaction record fails
        """
        if params is None:
            params = {}

        self._validate_add_deduct_point(
            "deduct", user_id, point_type, amount, remarks, trans_type, allow_negative, params
        )
        amount_rounded, decimal_places = self._round_amount(amount, point_type)

        # NOTE(review): as in add_point, the deduction and the record insert are not wrapped
        # in a single transaction here.
        try:
            success, balance_after = self.repository.deduct_balance_atomic(
                user_id, point_type, amount_rounded, allow_negative=allow_negative
            )
        except Exception as e:
            raise WalletOperationError(f"Error deducting user point: {e}") from e

        if not success:
            # The current balance is read only to build an informative error.
            available = float(self.repository.get_user_balance(user_id, point_type))
            # Report the amount that was actually attempted (after rounding).
            requested = float(amount_rounded)
            raise InsufficientBalanceError(
                f"Insufficient balance for this transaction. Available: {available:.{decimal_places}f}, "
                f"Requested: {requested:.{decimal_places}f}",
                available=available,
                requested=requested,
            )

        return self._save_transaction(
            user_id=user_id,
            point_type=point_type,
            entry_type="d",  # debit/deduct
            amount=amount_rounded,
            balance_after=balance_after,
            trans_type=trans_type,
            remarks=remarks,
            params=params,
            iid=iid,
        )

    def _save_transaction(
        self,
        *,
        user_id: int,
        point_type: str,
        entry_type: str,
        amount: Decimal,
        balance_after: Decimal,
        trans_type: int | None,
        remarks: str,
        params: Dict[str, Any],
        iid: int,
    ) -> int:
        """
        Build a TransactionRecord for a completed balance change and persist it.

        ``params["data"]`` is copied, so later mutations by the caller cannot affect the
        record. A falsy value (e.g. ``None``) becomes an empty dict.

        Returns:
            The id returned by ``repository.create_transaction_record``.

        Raises:
            WalletOperationError: If the repository fails to create the record.
        """
        record = TransactionRecord(
            id=None,
            wtype=point_type,
            iid=iid,
            uid=user_id,
            type=entry_type,
            amount=amount,
            balance=balance_after,
            trans_type=trans_type,
            descr=remarks,
            cdate=self._get_current_datetime(),
            extra_data=dict(params.get("data") or {}),
        )
        try:
            return self.repository.create_transaction_record(record)
        except Exception as e:
            raise WalletOperationError(f"Error creating transaction record: {e}") from e

    # ------------------------------------------------------------------ validation / amounts

    def _validate_user_point_type(self, user_id: int, point_type: str) -> None:
        """
        Validate that ``point_type`` is one of ``repository.get_point_types()``.

        ``user_id`` is accepted for signature parity with the legacy code and is not used.

        Raises:
            InvalidPointTypeError: If point type doesn't exist
        """
        valid_types = self.repository.get_point_types()
        if point_type not in valid_types:
            raise InvalidPointTypeError(
                f"No such point ({point_type}) exists!", point_type=point_type
            )

    def _validate_add_deduct_point(
        self,
        add_deduct: str,
        user_id: int,
        point_type: str,
        amount: Any,
        remarks: str,
        trans_type: int | None,
        allow_negative: bool,
        params: Dict[str, Any],
    ) -> None:
        """
        Validate parameters for add/deduct point operations.

        ``remarks``, ``trans_type`` and ``allow_negative`` are accepted for signature parity
        with the legacy implementation but are not checked here.

        Raises:
            InvalidPointTypeError: If point type doesn't exist
            InvalidParamsError: If ``params`` or ``params["data"]`` is not a dict, if
                ``params["data"]`` uses a reserved field name, or if ``amount`` is not a
                positive finite number
        """
        self._validate_user_point_type(user_id, point_type)

        method = "add_point" if add_deduct == "add" else "deduct_point"

        if not isinstance(params, dict):
            raise InvalidParamsError(f"Invalid params for {method}(), it must be a dict!", method=method)

        data = params.get("data")
        if data:
            if not isinstance(data, dict):
                raise InvalidParamsError(
                    f'Invalid params["data"] for {method}(), it must be a dict!', method=method
                )

            # Preserve the caller's key order so the error message is deterministic.
            invalid_fields = [name for name in data if name in self.RESERVED_FIELDS]
            if invalid_fields:
                fields_str = ", ".join(invalid_fields)
                raise InvalidParamsError(
                    f'Invalid params["data"] for {method}(), some of the data parameters are invalid: {fields_str}!',
                    method=method,
                )

        # A non-positive amount would silently invert the operation
        # (e.g. a negative deduction credits the wallet), so reject it up front.
        value = self._coerce_amount(amount)
        if value is None or value <= 0:
            raise InvalidParamsError(
                f"Invalid amount for {method}(), it must be a positive number!", method=method
            )

    @staticmethod
    def _coerce_amount(amount: Any) -> Decimal | None:
        """
        Convert ``amount`` to a finite ``Decimal``, or return ``None`` if that is not possible.

        ``Decimal`` and ``int`` values are converted exactly. ``float`` and ``str`` values go
        through ``str`` first, so ``0.1`` becomes ``Decimal("0.1")``. ``bool``, ``None``, NaN
        and infinities are rejected.
        """
        if isinstance(amount, bool):
            return None  # bool is an int subclass but never a meaningful amount
        if isinstance(amount, Decimal):
            value = amount
        elif isinstance(amount, int):
            value = Decimal(amount)
        elif isinstance(amount, (float, str)):
            try:
                value = Decimal(str(amount))
            except InvalidOperation:
                return None
        else:
            return None
        return value if value.is_finite() else None

    @staticmethod
    def _quantize(value: Decimal, decimal_places: int) -> Decimal:
        """Round ``value`` half-up to ``decimal_places`` digits after the decimal point."""
        return value.quantize(Decimal(1).scaleb(-decimal_places), rounding=ROUND_HALF_UP)

    def _round_amount(self, amount: Any, point_type: str) -> tuple[Decimal, int]:
        """
        Round an already-validated amount to the precision configured for ``point_type``.

        Returns:
            ``(rounded_amount, decimal_places)``. The decimal places come from
            ``repository.get_point_types()`` (2 if there is no entry), clamped to ``>= 0``.

        Raises:
            InvalidParamsError: If the amount cannot be represented at that precision.
        """
        decimal_places = max(0, int(self.repository.get_point_types().get(point_type, 2)))
        value = self._coerce_amount(amount)
        if value is None:  # defensive: callers validate the amount first
            raise InvalidParamsError(f"Invalid amount: {amount!r}")
        try:
            return self._quantize(value, decimal_places), decimal_places
        except InvalidOperation as e:
            # quantize signals InvalidOperation when the result exceeds the context precision.
            raise InvalidParamsError(
                f"Invalid amount: {amount} cannot be represented with {decimal_places} decimal places."
            ) from e

    # ------------------------------------------------------------------ read operations

    def get_transaction_history(
        self,
        user_id: int | None = None,
        point_type: str | None = None,
        trans_type: int | None = None,
        transaction_type: str | None = None,
        iid: int | None = None,
        start_date: datetime | date | str | None = None,
        end_date: datetime | date | str | None = None,
        limit: int | None = None,
        offset: int = 0,
    ) -> list[TransactionRecord]:
        """
        Retrieve wallet transaction history with various filters, newest first.

        Args:
            user_id: Filter by user ID (uid field)
            point_type: Filter by wallet/point type (wtype field)
            trans_type: Filter by transaction type code (trans_type field, integer constant)
            transaction_type: Filter by transaction type ('c' for credit, 'd' for debit) (type field)
            iid: Filter by initiator ID (iid field) - user who performed the transaction (-100 for system)
            start_date: Filter transactions from this date (inclusive) (cdate field)
            end_date: Filter transactions until this date (inclusive) (cdate field). A bare
                date or date-only ISO string (e.g. "2024-01-31") includes that whole day.
            limit: Maximum number of records to return (must be >= 0)
            offset: Number of records to skip (for pagination); negative values are ignored

        Returns:
            List of TransactionRecord objects

        Raises:
            InvalidPointTypeError: If ``point_type`` doesn't exist
            InvalidParamsError: If ``transaction_type`` is not 'c'/'d' or ``limit`` is negative
            ValueError: If a date string is not valid ISO format

        Example:
            # Get all transactions for a user
            transactions = service.get_transaction_history(user_id=123)

            # Get credit transactions for a specific point type
            transactions = service.get_transaction_history(
                user_id=123,
                point_type="credit_balance",
                transaction_type="c",
            )

            # Get transactions by transaction type code
            from wallet_utils.transaction_types import WALLET_DEPOSIT
            transactions = service.get_transaction_history(
                user_id=123,
                trans_type=WALLET_DEPOSIT,
            )

            # Get transactions in date range
            from datetime import datetime, timedelta
            end = datetime.now()
            start = end - timedelta(days=30)
            transactions = service.get_transaction_history(
                user_id=123,
                start_date=start,
                end_date=end,
                limit=100,
            )
        """
        if limit is not None and limit < 0:
            raise InvalidParamsError(f"Invalid limit: {limit}. Must be a non-negative integer.")

        queryset = self._filtered_queryset(
            user_id, point_type, trans_type, transaction_type, iid, start_date, end_date
        ).order_by("-cdate")

        # Apply offset and limit as a single slice; negative offsets are ignored.
        start = max(offset, 0)
        if limit is not None:
            queryset = queryset[start:start + limit]
        elif start:
            queryset = queryset[start:]

        return [self._to_record(obj) for obj in queryset]

    def get_transaction_by_id(self, transaction_id: int) -> TransactionRecord | None:
        """
        Get a single transaction by its ID.

        Args:
            transaction_id: The transaction ID

        Returns:
            TransactionRecord if found, None otherwise
        """
        wallet_model = self.repository.get_wallet_model()
        try:
            obj = wallet_model.objects.get(id=transaction_id)
        except wallet_model.DoesNotExist:
            return None
        return self._to_record(obj)

    def count_transactions(
        self,
        user_id: int | None = None,
        point_type: str | None = None,
        trans_type: int | None = None,
        transaction_type: str | None = None,
        iid: int | None = None,
        start_date: datetime | date | str | None = None,
        end_date: datetime | date | str | None = None,
    ) -> int:
        """
        Count transactions matching the given filters.

        The filters behave exactly as in ``get_transaction_history``.

        Args:
            user_id: Filter by user ID (uid field)
            point_type: Filter by wallet/point type (wtype field)
            trans_type: Filter by transaction type code (trans_type field, integer constant)
            transaction_type: Filter by transaction type ('c' for credit, 'd' for debit) (type field)
            iid: Filter by initiator ID (iid field) - user who performed the transaction (-100 for system)
            start_date: Filter transactions from this date (inclusive) (cdate field)
            end_date: Filter transactions until this date (inclusive) (cdate field). A bare
                date or date-only ISO string includes that whole day.

        Returns:
            Number of transactions matching the filters

        Raises:
            InvalidPointTypeError: If ``point_type`` doesn't exist
            InvalidParamsError: If ``transaction_type`` is not 'c'/'d'
            ValueError: If a date string is not valid ISO format
        """
        return self._filtered_queryset(
            user_id, point_type, trans_type, transaction_type, iid, start_date, end_date
        ).count()

    def _filtered_queryset(
        self,
        user_id: int | None,
        point_type: str | None,
        trans_type: int | None,
        transaction_type: str | None,
        iid: int | None,
        start_date: datetime | date | str | None,
        end_date: datetime | date | str | None,
    ) -> Any:
        """
        Build the (unordered, unsliced) wallet-model queryset for the shared read filters.

        Raises:
            InvalidPointTypeError: If ``point_type`` doesn't exist
            InvalidParamsError: If ``transaction_type`` is not 'c'/'d'
            ValueError: If a date string is not valid ISO format
        """
        wallet_model = self.repository.get_wallet_model()
        queryset = wallet_model.objects.all()

        if user_id is not None:
            queryset = queryset.filter(uid=user_id)

        if point_type is not None:
            self._validate_user_point_type(user_id, point_type)
            queryset = queryset.filter(wtype=point_type)

        if trans_type is not None:
            queryset = queryset.filter(trans_type=trans_type)

        if transaction_type is not None:
            if transaction_type not in ("c", "d"):
                raise InvalidParamsError(
                    f"Invalid transaction_type: {transaction_type}. Must be 'c' (credit) or 'd' (debit)."
                )
            queryset = queryset.filter(type=transaction_type)

        if iid is not None:
            queryset = queryset.filter(iid=iid)

        if start_date is not None:
            queryset = queryset.filter(**self._date_bound_filter(start_date, is_end=False))

        if end_date is not None:
            queryset = queryset.filter(**self._date_bound_filter(end_date, is_end=True))

        return queryset

    @staticmethod
    def _date_bound_filter(value: datetime | date | str, *, is_end: bool) -> Dict[str, Any]:
        """
        Build the ``cdate`` lookup for one end of an inclusive date range.

        - Datetimes and ISO datetime strings are exact bounds (``cdate__gte`` / ``cdate__lte``).
        - Bare dates and date-only ISO strings (e.g. ``"2024-01-31"``) cover the whole day.
          The start bound is that day's midnight. The end bound is strictly before the next
          day's midnight, so transactions later on the end date are still included.
        - Other values are passed through unchanged.

        NOTE(review): the resulting datetimes are naive; confirm this matches the project's
        timezone settings.

        Raises:
            ValueError: If ``value`` is a string that is not a valid ISO date/datetime.
        """
        if isinstance(value, str):
            try:
                value = date.fromisoformat(value)  # only succeeds for date-only strings
            except ValueError:
                value = datetime.fromisoformat(value)

        # datetime is a subclass of date, so check for a *bare* date explicitly.
        if isinstance(value, date) and not isinstance(value, datetime):
            day_start = datetime.combine(value, datetime.min.time())
            if is_end:
                return {"cdate__lt": day_start + timedelta(days=1)}
            return {"cdate__gte": day_start}

        return {"cdate__lte" if is_end else "cdate__gte": value}

    @staticmethod
    def _to_record(obj: Any) -> TransactionRecord:
        """Convert a wallet-model row into a TransactionRecord (cdate rendered as a string)."""
        cdate = obj.cdate
        return TransactionRecord(
            id=obj.id,
            wtype=obj.wtype,
            iid=obj.iid,
            uid=obj.uid,
            type=obj.type,
            amount=obj.amount,
            balance=obj.balance,
            trans_type=obj.trans_type,
            descr=obj.descr,
            cdate=cdate.isoformat() if hasattr(cdate, "isoformat") else str(cdate),
            extra_data=obj.extra_data or {},
        )

    @staticmethod
    def _get_current_datetime() -> str:
        """Return the current local time as a naive ISO-format string (``datetime.now()``)."""
        return datetime.now().isoformat()
