from typing import Dict, List, Optional

from collections import defaultdict
from urllib.parse import urljoin

from lxml.etree import XMLSyntaxError, fromstring
from requests.exceptions import RequestException

from apadata.api import API
from apadata.api.exceptions import NotFoundRequestError
from apadata.models import UrlTag
from apadata.spacy.constants import LangCode


class SitemapAPI(API):
    """
    Extracts a sitemap for a given domain and groups its URLs by URL tag.

    Parameters
    ----------
    domain : str
        Domain for which we will obtain the sitemap and then categorise it
    sitemap : Optional[List[str]]
        Optionally we can pass in the sitemap if it was previously obtained.
        The list is copied, so the caller's object is never modified.
    languages : Optional[List[LangCode]]
        Languages whose URL tags are looked up through
        ``UrlTag.get_translations_by_languages``; defaults to
        ``[LangCode("en")]``.
    """

    def __init__(
        self,
        domain: str,
        sitemap: Optional[List[str]] = None,
        languages: Optional[List[LangCode]] = None,
    ):
        super().__init__(api_url=f"https://{domain}/", headers={})
        self.domain = domain
        # Copy the input: determine_sitemap() extends self.sitemap in place and
        # must not leak those additions into the list owned by the caller.
        self.sitemap: List[str] = list(sitemap) if sitemap else []
        # Maps a URL tag to every sitemap URL that contains it.
        self.categories_to_urls: Dict[str, List[str]] = defaultdict(list)
        self.languages = languages or [LangCode("en")]
        self.url_tags = UrlTag.get_translations_by_languages(
            [str(lang.value) for lang in self.languages]
        )

    def get_sitemap_urls(self, url: str) -> List[str]:
        """
        Fetch a sitemap document and return the page URLs listed in it.

        Parameters
        ----------
        url : str
            Endpoint passed to ``self.get`` (e.g. ``"sitemap.xml"``).

        Returns
        -------
        List[str]
            Stripped, non-empty text of every ``<url><loc>`` element, or an
            empty list when the request raises ``RequestException`` or the
            body is not well-formed XML.
        """
        try:
            resp = self.get(endpoint=url)
        except RequestException:
            return []
        try:
            tree = fromstring(resp.content)
        except XMLSyntaxError:
            return []
        # NOTE(review): only <url><loc> children of the root are matched. A
        # sitemap index document normally lists <sitemap><loc> entries, so it
        # would yield nothing here -- confirm whether nested sitemaps should
        # be followed.
        urls: List[str] = []
        for loc in tree.findall("{*}url/{*}loc"):
            # loc.text is None for an empty element and is frequently padded
            # with whitespace/newlines in pretty-printed sitemaps.
            text = (loc.text or "").strip()
            if text:
                urls.append(text)
        return urls

    def determine_sitemap(self) -> List[str]:
        """
        Populate ``self.sitemap`` and return it.

        The well-known sitemap endpoints are fetched and their URLs are added
        to whatever sitemap was supplied at construction time. When URLs are
        available they are de-duplicated and sorted. Otherwise every URL tag
        is probed as an endpoint, and the ones that respond without raising
        are recorded as absolute URLs.

        Returns
        -------
        List[str]
            The resulting sitemap (also stored on ``self.sitemap``).
        """
        for ending in ("sitemap.xml", "sitemap_index.xml"):
            try:
                self.sitemap.extend(self.get_sitemap_urls(ending))
            except (AttributeError, NotFoundRequestError):
                # Best effort: a missing sitemap, or a response object without
                # ``content``, simply contributes no URLs.
                pass
        if self.sitemap:
            self.sitemap = sorted(set(self.sitemap))
        else:
            for tag in self.url_tags:
                try:
                    self.get(endpoint=tag)
                except (RequestException, NotFoundRequestError):
                    continue
                self.sitemap.append(urljoin(self.api_url, tag))
        return self.sitemap

    def categorize_sitemap(self) -> None:
        """
        Group the sitemap URLs by the URL tags they contain.

        A URL is appended to ``self.categories_to_urls[tag]`` for every tag
        that occurs as a substring of it, so one URL may land in several
        categories. Entries are appended on each call; calling this twice
        records the matches twice.
        """
        for url in self.sitemap:
            for tag in self.url_tags:
                if tag in url:
                    self.categories_to_urls[tag].append(url)
