"""Default Django repository implementation for wallet operations."""

from __future__ import annotations

from decimal import Decimal
from typing import TYPE_CHECKING, Protocol

from django.db import connection, transaction
from django.db.models import F

if TYPE_CHECKING:
    from django.db.models import Model


class WalletRepositoryProtocol(Protocol):
    """
    Interface for wallet operations.

    For most use cases, use WalletRepository, which implements every method
    below; you only need to define point types and their decimal places.

    Custom implementations must implement all methods below.
    """

    def get_point_types(self) -> dict[str, int]:
        """
        Return a dictionary mapping point types to their decimal places.

        Example:
            {
                "credit_balance": 2,
                "reward_points": 0,
                "crypto_balance": 8,
            }
        """
        ...

    def update_balance(
        self,
        user_id: int,
        point_type: str,
        amount: Decimal,
        allow_negative: bool = False,
    ) -> Decimal:
        """
        Add ``amount`` to the user's balance (positive to add, negative to deduct).

        Args:
            user_id: Owner of the balance.
            point_type: One of the keys returned by ``get_point_types``.
            amount: Signed delta to apply.
            allow_negative: If False, the update must be refused when the
                resulting balance would be below zero.

        Returns:
            The new balance after the update.

        Should be atomic (use database transactions).
        """
        ...

    def deduct_balance_atomic(
        self,
        user_id: int,
        point_type: str,
        amount: Decimal,
        allow_negative: bool = False,
    ) -> tuple[bool, Decimal]:
        """
        Atomically deduct ``amount`` using a conditional UPDATE to prevent races.

        When ``allow_negative`` is False the update is guarded by the balance
        itself, i.e. ``UPDATE ... WHERE {point_type} >= {amount}``, so it only
        applies if the balance covers the deduction. When ``allow_negative`` is
        True no balance condition is applied and the balance may go below zero.

        Returns:
            Tuple of (success: bool, new_balance: Decimal)
            - success: True if the row was updated, False otherwise
              (e.g. insufficient balance).
            - new_balance: The balance after deduction (if successful) or the
              current balance (if not).

        Should be atomic (use database transactions).
        """
        ...

    def get_user_balance(self, user_id: int, point_type: str) -> Decimal:
        """
        Get the current balance for a specific point type for a user.

        Used for reading the balance after operations.
        """
        ...

    def create_transaction_record(self, record) -> int:
        """
        Create a transaction record and return the new transaction ID.

        Args:
            record: TransactionRecord dataclass instance.

        Should be atomic (use database transactions).
        """
        ...


class WalletRepository:
    """
    Default Django implementation of WalletRepositoryProtocol.

    This repository handles all balance operations and transaction recording automatically.
    Users only need to define point types and their decimal places - that's it!

    Example:
        # Using User model for balances (default)
        repo = WalletRepository(
            user_model=User,
            point_types={
                "credit_balance": 2,
                "reward_points": 0,
            },
        )

        # Using separate Wallet model for balances
        repo = WalletRepository(
            user_model=User,
            wallet_balance_model=Wallet,
            point_types={
                "credit_balance": 2,
                "reward_points": 0,
            },
        )
    """

    # The only user id whose balance row is auto-created on first access.
    # Lookups for any other missing user raise the balance model's DoesNotExist.
    SYSTEM_USER_ID = -100

    def __init__(
        self,
        user_model: type[Model],
        point_types: dict[str, int],
        wallet_balance_model: type[Model] | None = None,
        wallet_model: type[Model] | None = None,
        point_type_field_map: dict[str, str] | None = None,
        wallet_user_id_field: str = "user_id",
    ):
        """
        Initialize the repository.

        Args:
            user_model: Django User model class (used for reference, or as balance model if wallet_balance_model not provided)
            point_types: Dictionary mapping point types to decimal places
            wallet_balance_model: Optional separate model for storing wallet balances (must have user_id field).
                                 If None, balances are stored on user_model.
            wallet_model: Optional custom WalletTransaction model (defaults to WalletTransaction)
            point_type_field_map: Optional mapping from point_type to model field name
            wallet_user_id_field: Field name in wallet_balance_model that references user_id (default: "user_id")
        """
        # Lazy import to avoid AppRegistryNotReady error
        if wallet_model is None:
            from .models import WalletTransaction
            wallet_model = WalletTransaction

        self._user_model = user_model
        self._wallet_balance_model = wallet_balance_model
        self._point_types = point_types
        self._wallet_model = wallet_model
        self._point_type_field_map = point_type_field_map or {}
        self._wallet_user_id_field = wallet_user_id_field

        # Determine which model to use for balance operations
        self._balance_model = wallet_balance_model or user_model
        self._balance_model_is_separate = wallet_balance_model is not None

    def get_point_types(self) -> dict[str, int]:
        """Return point types and their decimal places."""
        return self._point_types

    def get_user_model(self) -> type[Model]:
        """Return the User model class."""
        return self._user_model

    def get_wallet_model(self) -> type[Model]:
        """Return the WalletTransaction model class (never None after __init__)."""
        return self._wallet_model

    def get_wallet_balance_model(self) -> type[Model]:
        """Return the model used for storing wallet balances."""
        return self._balance_model

    def get_point_type_field(self, point_type: str) -> str:
        """Return the model field name for a point type (identity if unmapped)."""
        return self._point_type_field_map.get(point_type, point_type)

    def _balance_lookup(self, user_id: int) -> dict[str, int]:
        """Return ORM filter kwargs that select the balance row of ``user_id``."""
        if self._balance_model_is_separate:
            return {self._wallet_user_id_field: user_id}
        # ``pk`` (not ``id``) so custom user primary keys work too.
        return {"pk": user_id}

    def _system_user_defaults(self, user_id: int) -> dict[str, str]:
        """Return extra field values needed to create a system row on the user model."""
        if self._balance_model_is_separate:
            return {}
        defaults = {}
        if hasattr(self._user_model, 'username'):
            defaults['username'] = f"system_user_{user_id}"
        if hasattr(self._user_model, 'email'):
            defaults['email'] = f"system_{user_id}@system.local"
        return defaults

    def _get_balance_object(self, user_id: int, for_update: bool = False):
        """
        Get the balance object (wallet or user) for a given user_id.

        Creates the object if it doesn't exist ONLY for the system user
        (``SYSTEM_USER_ID``). For regular users, raises DoesNotExist if the
        row doesn't exist.

        Args:
            user_id: The user ID
            for_update: Whether to use select_for_update() for row locking
                (callers must be inside a transaction).

        Returns:
            The balance model instance

        Raises:
            DoesNotExist: If the row doesn't exist and user_id is not the system user
        """
        manager = self._balance_model.objects
        queryset = manager.select_for_update() if for_update else manager.all()
        lookup = self._balance_lookup(user_id)

        if user_id == self.SYSTEM_USER_ID:
            # get_or_create re-reads on IntegrityError, so two concurrent first
            # accesses to the system row do not fail.
            obj, _created = queryset.get_or_create(
                **lookup, defaults=self._system_user_defaults(user_id)
            )
            return obj
        return queryset.get(**lookup)

    def _read_balance(self, user_id: int, field_name: str) -> Decimal | None:
        """Read ``field_name`` of ``user_id``'s row as a Decimal, or None if no row exists."""
        values = list(
            self._balance_model.objects.filter(**self._balance_lookup(user_id))
            .values_list(field_name, flat=True)[:1]
        )
        if not values:
            return None
        # Go through str() so integer columns convert exactly like decimals.
        return Decimal(str(values[0]))

    def _try_deduct(
        self, user_id: int, field_name: str, amount: Decimal, allow_negative: bool
    ) -> bool:
        """Run one conditional UPDATE subtracting ``amount``; return True if a row changed."""
        filters = self._balance_lookup(user_id)
        if not allow_negative:
            # Sufficiency check lives in the WHERE clause, so it is evaluated
            # atomically with the write by the database.
            filters[f"{field_name}__gte"] = amount
        updated = self._balance_model.objects.filter(**filters).update(
            **{field_name: F(field_name) - amount}
        )
        return updated > 0

    @transaction.atomic
    def update_balance(
        self,
        user_id: int,
        point_type: str,
        amount: Decimal,
        allow_negative: bool = False,
    ) -> Decimal:
        """
        Update user balance by adding amount (positive for add, negative for deduct).

        The balance row is locked (SELECT ... FOR UPDATE) before the
        sufficiency check, so a concurrent update cannot change the balance
        between the check and the write.

        Returns:
            The new balance after the update.

        Raises:
            ValueError: If allow_negative is False and the balance would go below zero.
            DoesNotExist: If the balance row is missing and user_id is not the system user.
        """
        field_name = self.get_point_type_field(point_type)

        # Always lock (and, for the system user, create) the row first; without
        # this the UPDATE below could hit zero rows and silently drop the amount.
        balance_obj = self._get_balance_object(user_id, for_update=True)

        if not allow_negative:
            current_balance = getattr(balance_obj, field_name)
            if current_balance + amount < 0:
                raise ValueError(f"Insufficient balance: {current_balance} < {abs(amount)}")

        # F() expression: arithmetic happens in the database, avoiding
        # Python-side rounding issues.
        self._balance_model.objects.filter(**self._balance_lookup(user_id)).update(
            **{field_name: F(field_name) + amount}
        )

        balance_obj.refresh_from_db(fields=[field_name])
        return getattr(balance_obj, field_name)

    @transaction.atomic
    def deduct_balance_atomic(
        self,
        user_id: int,
        point_type: str,
        amount: Decimal,
        allow_negative: bool = False,
    ) -> tuple[bool, Decimal]:
        """
        Atomically deduct balance using a conditional UPDATE to prevent race conditions.

        Issues ``UPDATE ... SET {field} = {field} - amount WHERE <user> AND {field} >= amount``;
        when allow_negative is True the ``{field} >= amount`` condition is omitted.

        If the balance row does not exist yet, it is created (system user only;
        other users raise DoesNotExist) and the deduction is retried once.

        Returns:
            Tuple of (success: bool, new_balance: Decimal)
            - success: True if row was updated (sufficient balance), False otherwise
            - new_balance: The balance after deduction (if successful) or current balance (if failed)
        """
        field_name = self.get_point_type_field(point_type)

        # Handle zero amount - no need to update, just return current balance
        if amount == 0:
            return True, self.get_user_balance(user_id, point_type)

        if self._try_deduct(user_id, field_name, amount, allow_negative):
            return True, self._read_balance(user_id, field_name)

        current_balance = self._read_balance(user_id, field_name)
        if current_balance is None:
            # Row is missing: create it (raises DoesNotExist for non-system
            # users) and retry, so e.g. an allow_negative deduction from a
            # fresh system wallet is not rejected.
            self._get_balance_object(user_id, for_update=True)
            if self._try_deduct(user_id, field_name, amount, allow_negative):
                return True, self._read_balance(user_id, field_name)
            current_balance = self._read_balance(user_id, field_name)
        return False, current_balance

    def get_user_balance(self, user_id: int, point_type: str) -> Decimal:
        """Get the current balance for a specific point type for a user."""
        field_name = self.get_point_type_field(point_type)
        balance_obj = self._get_balance_object(user_id, for_update=False)
        return getattr(balance_obj, field_name)

    @transaction.atomic
    def create_transaction_record(self, record) -> int:
        """
        Create a transaction record and return the new transaction ID.

        Args:
            record: TransactionRecord dataclass instance; ``cdate`` may be a
                datetime or an ISO-8601 string.

        Returns:
            The primary key of the created record.
        """
        from datetime import datetime

        # Convert ISO string to datetime if needed
        cdate = record.cdate
        if isinstance(cdate, str):
            cdate = datetime.fromisoformat(cdate)

        transaction_obj = self._wallet_model.objects.create(
            wtype=record.wtype,
            iid=record.iid,
            uid=record.uid,
            type=record.type,
            amount=record.amount,
            balance=record.balance,
            trans_type=record.trans_type,
            descr=record.descr,
            cdate=cdate,
            extra_data=record.extra_data,
        )
        # ``pk`` rather than ``id`` so custom wallet models with another primary key work.
        return transaction_obj.pk
