from typing import Dict, List, Optional, Tuple

from fastapi import HTTPException, status
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.product import Product
from app.models.violation import Violation
from app.models.scraping_result import ScrapingResult
from app.schemas.product import ProductCreate, ProductUpdate
from app.services.pricing_service import calculate_pack_prices


class ProductService:
    """Async data-access helpers for ``Product`` rows.

    Violations are linked to products by name: every count below matches
    ``Violation.product_name`` against ``Product.product_name``. Lookups by
    id that find nothing raise ``HTTPException(404)`` (except
    ``get_violation_count``, which returns 0).
    """

    # Pack sizes that have a corresponding ``price_<n>_pack`` column on Product.
    _PACK_SIZES = (1, 2, 3, 4, 5, 6, 12)

    @staticmethod
    def _pack_price_fields(msp) -> dict:
        """Map each ``price_<n>_pack`` column name to the price derived from *msp*."""
        pack_prices = calculate_pack_prices(float(msp))
        return {
            f"price_{size}_pack": pack_prices[size]
            for size in ProductService._PACK_SIZES
        }

    @staticmethod
    async def _find_product(db: AsyncSession, product_id: int) -> Optional[Product]:
        """Return the product with *product_id*, or ``None`` if it does not exist."""
        result = await db.execute(select(Product).where(Product.id == product_id))
        return result.scalars().first()

    @staticmethod
    async def _get_product_or_404(db: AsyncSession, product_id: int) -> Product:
        """Return the product with *product_id*.

        Raises:
            HTTPException: 404 if no such product exists.
        """
        product = await ProductService._find_product(db, product_id)
        if product is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Product not found",
            )
        return product

    @staticmethod
    async def _count_violations_for_name(
        db: AsyncSession, product_name: Optional[str]
    ) -> int:
        """Count violations whose ``product_name`` equals *product_name*."""
        result = await db.execute(
            select(func.count(Violation.id)).where(
                Violation.product_name == product_name
            )
        )
        return result.scalar() or 0

    @staticmethod
    async def _violation_counts_by_name(
        db: AsyncSession, products: List[Product]
    ) -> Dict[Optional[str], int]:
        """Count violations for all *products* in one grouped query.

        Returns a mapping of product name -> violation count; names with no
        violations are absent from the mapping.
        """
        names = {product.product_name for product in products}
        non_null_names = [name for name in names if name is not None]

        conditions = []
        if non_null_names:
            conditions.append(Violation.product_name.in_(non_null_names))
        if None in names:
            # SQL ``IN`` never matches NULL; match it explicitly so a product
            # with a NULL name is counted like the single-product query does.
            conditions.append(Violation.product_name.is_(None))
        if not conditions:
            return {}

        result = await db.execute(
            select(Violation.product_name, func.count(Violation.id))
            .where(or_(*conditions))
            .group_by(Violation.product_name)
        )
        return {name: count for name, count in result.all()}

    @staticmethod
    async def get_violation_count(db: AsyncSession, product_id: int) -> int:
        """Count violations for a specific product.

        Returns 0 when the product does not exist.
        """
        product = await ProductService._find_product(db, product_id)
        if product is None:
            return 0
        return await ProductService._count_violations_for_name(
            db, product.product_name
        )

    @staticmethod
    async def enrich_product_with_violations(db: AsyncSession, product: Product) -> Product:
        """Set ``product.violation_count`` from violations matching its name.

        The product is mutated in place and also returned for convenience.
        """
        product.violation_count = await ProductService._count_violations_for_name(  # type: ignore
            db, product.product_name
        )
        return product

    @staticmethod
    async def get_products(
        db: AsyncSession,
        page: int = 1,
        limit: int = 10,
        sort_by: str = "product_name",
        search: Optional[str] = None,
    ) -> Tuple[List[Product], int]:
        """Return one page of products plus the total number of matches.

        Args:
            db: Async database session.
            page: 1-based page number; values below 1 are treated as 1.
            limit: Page size.
            sort_by: ``"msp"``, ``"product_name"`` or ``"last_scraped_date"``;
                anything else sorts by product name.
            search: Optional case-insensitive substring to match against the
                product name. ``%`` and ``_`` are matched literally.

        Returns:
            ``(products, total)``. Each product has ``violation_count`` set.
        """
        offset = (max(page, 1) - 1) * limit

        filtered = select(Product)
        if search:
            # Escape LIKE wildcards so user input is matched literally.
            escaped = (
                search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            )
            filtered = filtered.where(
                Product.product_name.ilike(f"%{escaped}%", escape="\\")
            )

        # Count from the unordered query: ORDER BY is irrelevant to the count
        # and some backends reject it inside a subquery.
        count_query = select(func.count()).select_from(filtered.subquery())
        total = await db.scalar(count_query) or 0

        if sort_by == "msp":
            ordering = [Product.msp]
        elif sort_by == "last_scraped_date":
            # For now sorting by updated_at as placeholder for last_scraped_date
            ordering = [Product.updated_at.desc()]
        else:
            ordering = [Product.product_name]
        # Product.id breaks ties so pages are stable across requests.
        query = filtered.order_by(*ordering, Product.id).offset(offset).limit(limit)

        result = await db.execute(query)
        products = list(result.scalars().all())

        # One grouped query for the whole page instead of one per product.
        counts = await ProductService._violation_counts_by_name(db, products)
        for product in products:
            product.violation_count = counts.get(product.product_name, 0)  # type: ignore

        return products, total

    @staticmethod
    async def create_product(db: AsyncSession, product_in: ProductCreate) -> Product:
        """Create and commit a new product with pack prices derived from its MSP."""
        db_product = Product(
            reference_id=product_in.reference_id,
            product_name=product_in.product_name,
            barcode=product_in.barcode,
            msp=product_in.msp,
            status=product_in.status,
            **ProductService._pack_price_fields(product_in.msp),
        )
        db.add(db_product)
        await db.commit()
        await db.refresh(db_product)
        return db_product

    @staticmethod
    async def update_product(
        db: AsyncSession, product_id: int, product_in: ProductUpdate
    ) -> Product:
        """Apply the fields explicitly set on *product_in* to a product.

        If MSP changes, all pack prices are recalculated.

        Raises:
            HTTPException: 404 if no such product exists.
        """
        product = await ProductService._get_product_or_404(db, product_id)

        update_data = product_in.model_dump(exclude_unset=True)
        if "msp" in update_data:
            update_data.update(ProductService._pack_price_fields(update_data["msp"]))

        for field, value in update_data.items():
            setattr(product, field, value)

        await db.commit()
        await db.refresh(product)
        # Count after committing so a renamed product reports its new name's violations.
        return await ProductService.enrich_product_with_violations(db, product)

    @staticmethod
    async def delete_product(db: AsyncSession, product_id: int) -> None:
        """Delete a product and commit.

        Raises:
            HTTPException: 404 if no such product exists.
        """
        db_product = await ProductService._get_product_or_404(db, product_id)
        await db.delete(db_product)
        await db.commit()

    @staticmethod
    async def get_product_by_id(db: AsyncSession, product_id: int) -> Product:
        """Return a product with ``violation_count`` populated.

        Raises:
            HTTPException: 404 if no such product exists.
        """
        db_product = await ProductService._get_product_or_404(db, product_id)
        return await ProductService.enrich_product_with_violations(db, db_product)
