'''
Created on 2011-10-03
@author: jacekf

Common routing classes, regardless of whether used in HTTP or multiprocess context
'''
from collections import defaultdict
from corepost import Response, RESTException
from corepost.enums import Http, HttpHeader
from corepost.utils import getMandatoryArgumentNames, safeDictUpdate
from corepost.convert import convertForSerialization, generateXml, convertToJson
from corepost.filters import IRequestFilter, IResponseFilter

from .enums import MediaType
from twisted.internet import defer
from twisted.web.http import parse_qs
from twisted.python import log
import re, copy, yaml,json, logging
from xml.etree import ElementTree
import uuid


class UrlRouter:
    ''' Common class for containing info related to routing a request to a function '''

    # Finds URL placeholders: <name>, <int:name>, <float:name> or <uuid:name>.
    # The type prefix is only recognised together with its ':' separator, so a plain string
    # argument whose name merely starts with a type name (e.g. <integer>) is not misparsed,
    # and '>' is excluded from the name so several placeholders may share one path segment.
    __urlMatcher = re.compile(r"<(?:(int|float|uuid):)?([^/>]+)>")
    # placeholder type -> regex template for it (the template is formatted with the argument name)
    __urlRegexReplace = {"": r"(?P<%s>([^/]+))",
                         "int": r"(?P<%s>\d+)",
                         # escaped '.', so only a literal decimal point is accepted
                         "float": r"(?P<%s>\d+\.?\d*)",
                         # UUID text is case-insensitive and uuid.UUID accepts either case
                         "uuid": r"(?P<%s>[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})"}
    # placeholder type -> callable converting the matched text (plain strings are not converted)
    __typeConverters = {"int":int,"float":float,"uuid":uuid.UUID}

    def __init__(self,f,url,methods,accepts,produces,cache):
        '''
        f: the route function, invoked as f(instance, request, **arguments) by call()
        url: URL pattern, possibly containing <placeholders>
        methods: one HTTP method or a tuple/list of them
        accepts: one accepted content type or a tuple/list of them
        produces: stored, not otherwise used in this class
        cache: whether routing results for this URL may be cached
        '''
        self.__f = f
        self.__url = url
        self.__methods = tuple(methods) if isinstance(methods,(tuple,list)) else (methods,)
        self.__accepts = tuple(accepts) if isinstance(accepts,(tuple,list)) else (accepts,)
        self.__produces = produces
        self.__cache = cache
        self.__argConverters = {} # dict of arg name -> converter callable (None for plain strings)
        self.__validators = {}
        # mandatory arguments of f, skipping the leading (instance, request) pair passed by call()
        self.__mandatory = getMandatoryArgumentNames(f)[2:]

    def compileMatcherForFullUrl(self):
        """Compiles the regex matches once the URL has been updated to include the full path from the parent class"""
        converters = {}

        def placeholderToRegex(match):
            # findall/sub report an absent type prefix as None here, which means a plain string
            argType, argName = match.group(1) or "", match.group(2)
            converters[argName] = UrlRouter.__typeConverters.get(argType)
            return UrlRouter.__urlRegexReplace[argType] % argName

        # static parts of the URL are used verbatim, placeholders become named groups
        self.__matchUrl = "^%s$" % UrlRouter.__urlMatcher.sub(placeholderToRegex, self.url)
        # rebuilt on every compilation so converters from a previous URL cannot linger
        self.__argConverters = converters
        self.__matcher = re.compile(self.__matchUrl)

    @property
    def cache(self):
        '''Indicates if this URL should be cached or not'''
        return self.__cache

    @property
    def methods(self):
        '''Tuple of HTTP methods this route handles'''
        return self.__methods

    @property
    def url(self):
        '''URL pattern of this route'''
        return self.__url

    @url.setter
    def url(self, value):
        # RequestRouter prefixes the service root path and standardizes the URL after construction;
        # call compileMatcherForFullUrl() afterwards so the matcher reflects the new URL
        self.__url = value

    @property
    def accepts(self):
        '''Tuple of request content types this route accepts'''
        return self.__accepts

    def addValidator(self,fieldName,validator):
        '''Adds additional field-specific formencode validators (stored; not consulted by this class)'''
        self.__validators[fieldName] = validator

    def getArguments(self,url):
        '''
        Returns None if nothing matched (i.e. URL does not match), empty dict if no args found (i,e, static URL)
        or dict with arg/values for dynamic URLs, converted to int/float/uuid.UUID where the pattern says so
        '''
        match = self.__matcher.fullmatch(url)
        if match is None:
            return None
        args = match.groupdict()
        # convert to expected datatypes
        for name in list(args):
            converter = self.__argConverters.get(name)
            if converter is not None:
                args[name] = converter(args[name])
        return args

    def call(self,instance,request,**kwargs):
        '''
        Forwards call to underlying method.
        Raises TypeError when a mandatory argument of the function is missing from kwargs.
        '''
        for arg in self.__mandatory:
            if arg not in kwargs:
                raise TypeError("Missing mandatory argument '%s'" % arg)
        return self.__f(instance,request,**kwargs)

    def __str__(self):
        return "%s %s" % (self.url, self.methods)

class UrlRouterInstance:
    """
    Pairs a UrlRouter with the service object its route function runs against
    (the service is handed to UrlRouter.call as the function's first argument).
    """

    def __init__(self, clazz, urlRouter):
        self.urlRouter = urlRouter
        self.clazz = clazz

    def __str__(self):
        router = self.urlRouter
        return router.url

class CachedUrl:
    '''
    Remembers the outcome of routing a concrete URL once: the UrlRouterInstance it was
    routed to and the path arguments extracted from it. Reusing it spares the regex
    matching on later calls to commonly accessed REST URLs. Both values are read-only.
    '''
    def __init__(self,urlRouterInstance,args):
        self.__urlRouterInstance, self.__args = urlRouterInstance, args

    urlRouterInstance = property(lambda self: self.__urlRouterInstance,
                                 doc="The UrlRouterInstance the URL was routed to")

    args = property(lambda self: self.__args,
                    doc="Path arguments extracted from the URL")
    
class RequestRouter:
    '''
    Class that handles request->method routing functionality to any type of resource.

    Request filters run before routing. Response filters run on every Response produced,
    including Responses produced once a Deferred returned by a route function fires.
    '''

    def __init__(self,restServiceContainer,schema=None,filters=()):
        '''
        restServiceContainer: object whose "services" attribute lists the service instances
            whose classes define the route functions
        schema: stored, not otherwise used in this class
        filters: objects providing IRequestFilter and/or IResponseFilter (None means no filters)
        Raises RuntimeError when a filter provides neither interface.
        '''
        knownMethods = (Http.GET,Http.POST,Http.PUT,Http.DELETE,Http.OPTIONS,Http.PATCH,Http.HEAD)
        # HTTP method -> URL pattern -> accepted content type -> UrlRouterInstance
        self.__urls = {method: defaultdict(dict) for method in knownMethods}
        # HTTP method -> concrete request path -> content type -> CachedUrl
        self.__cachedUrls = {method: defaultdict(dict) for method in knownMethods}
        # route function -> its UrlRouterInstance
        self.__urlRouterInstances = {}
        self.__schema = schema
        self.__registerRouters(restServiceContainer)
        self.__urlContainer = restServiceContainer
        self.__requestFilters = []
        self.__responseFilters = []

        if filters is not None:
            for webFilter in filters:
                isRequestFilter = IRequestFilter.providedBy(webFilter)
                isResponseFilter = IResponseFilter.providedBy(webFilter)
                if isRequestFilter:
                    self.__requestFilters.append(webFilter)
                if isResponseFilter:
                    self.__responseFilters.append(webFilter)
                if not (isRequestFilter or isResponseFilter):
                    raise RuntimeError("filter %s must implement IRequestFilter or IResponseFilter" % webFilter.__class__.__name__)

    @property
    def path(self):
        # NOTE(review): __path is never assigned in this class, so reading this property raises
        # AttributeError unless something else sets _RequestRouter__path
        return self.__path

    def __registerRouters(self, restServiceContainer):
        """Main method responsible for registering routers: every route-decorated function of each
        service is registered under each of its HTTP methods, its full URL and each accepted content type"""
        from types import FunctionType

        for service in restServiceContainer.services:
            serviceClass = service.__class__
            # check if the service has a root path defined, which is optional
            rootPath = serviceClass.path if "path" in serviceClass.__dict__ else ""

            for func in serviceClass.__dict__.values():
                # handle REST resources directly on the CorePost resource
                if type(func) != FunctionType or not hasattr(func,'corepostRequestRouter'):
                    continue
                rq = func.corepostRequestRouter
                # workaround for multiple passes of __registerRouters (for unit tests etc):
                # the root path is only prepended the first time
                if not hasattr(rq, 'urlAdapted'):
                    url = "%s%s" % (rootPath,rq.url)
                    # remove one leading and one trailing '/' to standardize URLs
                    # (written so that an empty URL is handled too)
                    if url.startswith("/"):
                        url = url[1:]
                    if url.endswith("/"):
                        url = url[:-1]
                    rq.url = url
                    rq.urlAdapted = True

                # now that the full URL is set, compile the matcher for it
                rq.compileMatcherForFullUrl()
                urlRouterInstance = UrlRouterInstance(service,rq)
                # needed so that we can lookup the urlRouterInstance for a specific function
                self.__urlRouterInstances[func] = urlRouterInstance
                for method in rq.methods:
                    for accepts in rq.accepts:
                        self.__urls[method][rq.url][accepts] = urlRouterInstance

    def getResponse(self,request):
        """
        Finds the function registered for the request's method, path and content type, dispatches
        the request to it and returns the resulting Response. When the function returns a Deferred,
        that Deferred is returned instead and fires with the (filtered) Response.
        Errors raised while routing or calling the function are converted into error Responses.
        """
        response = None
        try:
            if self.__requestFilters:
                self.__filterRequests(request)

            path = self.__standardizePath(request.postpath)
            contentType = self.__getRequestMediaType(request)
            urlRouterInstance, pathargs = self.__findUrlRouterInstance(request.method, path, contentType)

            if urlRouterInstance is not None and pathargs is not None:
                response = self.__callUrlRouter(urlRouterInstance, pathargs, request)
                if isinstance(response, defer.Deferred):
                    # response filters are applied inside the Deferred chain once it fires
                    return response
            elif self.__isDefinedOnlyForOtherMethods(request.method, path):
                # the URL is defined, but not for the requested method
                response = self.__createErrorResponse(request,501, "")
            else:
                log.msg("URL %s not found" % path,logLevel=logging.WARN)
                response = self.__createErrorResponse(request,404,"URL '%s' not found\n" % request.path)

        except Exception as ex:
            log.err(ex)
            response = self.__createErrorResponse(request,500,"Internal server error: %s" % ex)

        # response handling
        if response is not None and self.__responseFilters:
            self.__filterResponses(request,response)

        return response

    @staticmethod
    def __standardizePath(postpath):
        """Joins the request's postpath segments with '/', dropping the empty segment left by a trailing '/'
        (a lone [''] is kept as is)"""
        if postpath and postpath[-1] == '' and postpath != ['']:
            postpath = postpath[:-1]
        return '/'.join(postpath)

    @staticmethod
    def __mediaType(headerValue):
        """Returns a Content-Type header value without its parameters,
        e.g. 'application/json; charset=utf-8' -> 'application/json'"""
        if not isinstance(headerValue, str):
            return headerValue
        return headerValue.split(";", 1)[0].strip()

    def __getRequestMediaType(self,request):
        """Media type of the request body, or the wildcard media type when no Content-Type header was sent"""
        headers = request.received_headers
        if HttpHeader.CONTENT_TYPE not in headers:
            return MediaType.WILDCARD
        return self.__mediaType(headers[HttpHeader.CONTENT_TYPE])

    def __findUrlRouterInstance(self,method,path,contentType):
        """
        Returns (UrlRouterInstance, path arguments) for the request, or (None, None) when no route matches.
        URL patterns are tried in registration order; for each one the route registered for contentType
        is used, falling back to its wildcard route.
        """
        # fetch URL arguments <-> function from cache if hit at least once before
        # (.get so that lookups for unknown methods/paths never add cache entries)
        cachedUrl = self.__cachedUrls.get(method,{}).get(path,{}).get(contentType)
        if cachedUrl is not None:
            return cachedUrl.urlRouterInstance, cachedUrl.args

        # first time this URL is called
        for contentTypeInstances in self.__urls.get(method,{}).values():
            if contentType in contentTypeInstances:
                # there is an exact function for this incoming content type
                instance = contentTypeInstances[contentType]
            elif MediaType.WILDCARD in contentTypeInstances:
                # fall back to any wildcard method
                instance = contentTypeInstances[MediaType.WILDCARD]
            else:
                continue

            # see if the path arguments match up against the function's @route definition
            args = instance.urlRouter.getArguments(path)
            if args is not None:
                if instance.urlRouter.cache:
                    self.__cachedUrls[method][path][contentType] = CachedUrl(instance,args)
                return instance, args
        return None, None

    def __matchesAnyRoute(self,method,path):
        """True when any route registered for the HTTP method matches path, whatever its content type"""
        for contentTypeInstances in self.__urls.get(method,{}).values():
            for instance in contentTypeInstances.values():
                if instance.urlRouter.getArguments(path) is not None:
                    return True
        return False

    def __isDefinedOnlyForOtherMethods(self,method,path):
        """True when path matches no route of the requested method but matches a route of another method
        (works for dynamic URL patterns too, not only static URLs)"""
        if self.__matchesAnyRoute(method, path):
            return False
        return any(self.__matchesAnyRoute(other, path) for other in self.__urls if other != method)

    def __callUrlRouter(self,urlRouterInstance,pathargs,request):
        """
        Parses the request body/arguments, calls the routed function and returns its Response,
        or the function's Deferred with callbacks attached that turn its result into a filtered Response.
        TypeError (bad input, missing arguments) becomes 400, RESTException its own response,
        anything else 500.
        """
        # copy so that merging request arguments never alters cached path arguments
        allargs = copy.deepcopy(pathargs)
        try:
            # if POST/PUT, check if we need to automatically parse JSON, YAML, XML
            self.__parseRequestData(request)
            # parse request arguments from form or JSON/YAML/XML documents
            self.__addRequestArguments(request, allargs)
            urlRouter = urlRouterInstance.urlRouter
            val = urlRouter.call(urlRouterInstance.clazz,request,**allargs)

            #handle Deferreds natively
            if isinstance(val,defer.Deferred):
                # add callbacks to finish the request and to run the response filters,
                # which getResponse cannot apply to a Deferred itself
                val.addCallback(self.__finishDeferred,request)
                val.addErrback(self.__finishDeferredError,request)
                val.addCallback(self.__filterDeferredResponse,request)
                return val

            #special logic for POST to return 201 (created)
            if request.method == Http.POST and getattr(request, 'code', 200) == 200:
                request.setResponseCode(201)

            return self.__generateResponse(request, val, request.code)

        except TypeError as ex:
            log.msg(ex,logLevel=logging.WARN)
            return self.__createErrorResponse(request,400,"%s" % ex)

        except RESTException as ex:
            return self.__restExceptionResponse(ex)

        except Exception as ex:
            log.err(ex)
            return self.__createErrorResponse(request,500,"Unexpected server error: %s\n%s" % (type(ex),ex))

    def __restExceptionResponse(self,ex):
        """Convert REST exceptions to their responses. Input errors log at a lower level to avoid overloading logs"""
        if ex.response.code in (400,404):
            log.msg(ex,logLevel=logging.WARN)
        else:
            log.err(ex)
        return ex.response

    def __generateResponse(self,request,response,code=200):
        """
        Takes care of automatically rendering the response and converting it to appropriate format (text,XML,JSON,YAML)
        depending on what the caller can accept. Returns Response
        """
        if isinstance(response, str):
            return Response(code,response,{HttpHeader.CONTENT_TYPE: MediaType.TEXT_PLAIN})
        elif isinstance(response, Response):
            return response
        else:
            (content,contentType) = self.__convertObjectToContentType(request, response)
            return Response(code,content,{HttpHeader.CONTENT_TYPE:contentType})

    def __convertObjectToContentType(self,request,obj):
        """
        Takes care of converting an object (non-String) response to the appropriate format, based on the what the caller can accept.
        JSON is preferred, then YAML, then XML; JSON is also the default when the Accept header names none of them
        or is absent. Returns a tuple of (content,contentType)
        """
        obj = convertForSerialization(obj)

        headers = request.received_headers
        accept = headers[HttpHeader.ACCEPT] if HttpHeader.ACCEPT in headers else ""
        if MediaType.APPLICATION_JSON in accept:
            return (convertToJson(obj),MediaType.APPLICATION_JSON)
        elif MediaType.TEXT_YAML in accept:
            return (yaml.dump(obj),MediaType.TEXT_YAML)
        elif MediaType.APPLICATION_XML in accept or MediaType.TEXT_XML in accept:
            return (generateXml(obj),MediaType.APPLICATION_XML)
        # no (recognised) Accept header: default to JSON
        return (convertToJson(obj),MediaType.APPLICATION_JSON)

    def __finishDeferred(self,val,request):
        """Finishes any Defered/inlineCallback methods. Returns Response"""
        # NOTE(review): unlike the synchronous path, POST results are not switched to 201 here
        if isinstance(val,Response):
            return val
        if val is None:
            # NOTE(review): 209 is not a standard HTTP status (204 No Content may have been meant); kept as is
            return Response(209,None)
        try:
            return self.__generateResponse(request,val)
        except Exception as ex:
            log.err(ex)
            msg = "Unexpected server error: %s\n%s" % (type(ex),ex)
            return self.__createErrorResponse(request, 500, msg)

    def __finishDeferredError(self,error,request):
        """Finishes any Defered/inlineCallback methods that raised an error. Returns Response.
        A RESTException is turned into its own response, as on the synchronous path; anything else is a 500."""
        if error.check(RESTException):
            return self.__restExceptionResponse(error.value)
        log.err(error, "Deferred failed")
        return self.__createErrorResponse(request, 500,"Internal server error")

    def __filterDeferredResponse(self,response,request):
        """Runs the response filters on a Response produced by a Deferred and passes it on"""
        if response is not None and self.__responseFilters:
            self.__filterResponses(request,response)
        return response

    def __createErrorResponse(self,request,code,message):
        """Common method for rendering errors as plain-text Responses"""
        return Response(code=code, entity=message, headers={"content-type": MediaType.TEXT_PLAIN})

    def __parseRequestData(self,request):
        '''
        Automatically parses JSON,XML,YAML if present: for POST/PUT requests with a Content-Type header,
        reads the body into request.data and parses it into request.json, request.xml or request.yaml.
        Content-Type parameters such as "; charset=utf-8" are ignored when choosing the parser.
        Raises TypeError when the body cannot be parsed (turned into a 400 by __callUrlRouter).
        '''
        headers = request.received_headers
        if request.method not in (Http.POST,Http.PUT) or HttpHeader.CONTENT_TYPE not in headers:
            return

        contentType = self.__mediaType(headers[HttpHeader.CONTENT_TYPE])
        request.data = request.content.read()

        if contentType == MediaType.APPLICATION_JSON:
            try:
                request.json = json.loads(request.data) if request.data else {}
            except Exception as ex:
                raise TypeError("Unable to parse JSON body: %s" % ex) from ex
        elif contentType in (MediaType.APPLICATION_XML,MediaType.TEXT_XML):
            try:
                request.xml = ElementTree.XML(request.data)
            except Exception as ex:
                raise TypeError("Unable to parse XML body: %s" % ex) from ex
        elif contentType == MediaType.TEXT_YAML:
            try:
                # safe_load: the body is untrusted input
                request.yaml = yaml.safe_load(request.data)
            except Exception as ex:
                raise TypeError("Unable to parse YAML body: %s" % ex) from ex

    @staticmethod
    def __mergeDocument(allargs,document):
        """Merges the root keys of a parsed JSON/YAML document into allargs; documents that are not
        mappings (arrays, scalars, empty YAML) contribute no arguments"""
        if isinstance(document, dict):
            for key, value in document.items():
                safeDictUpdate(allargs, key, value)

    def __addRequestArguments(self,request,allargs):
        """Parses the request form arguments OR JSON/YAML document root elements OR XML attributes/root nodes
        and merges them into allargs via safeDictUpdate (presumably without overriding existing path
        arguments - confirm in corepost.utils)"""
        # handler for weird Twisted logic where PUT does not get form params
        # see: http://twistedmatrix.com/pipermail/twisted-web/2007-March/003338.html
        requestargs = request.args
        headers = request.received_headers

        if request.method == Http.PUT and HttpHeader.CONTENT_TYPE in headers \
                and self.__mediaType(headers[HttpHeader.CONTENT_TYPE]) == MediaType.APPLICATION_FORM_URLENCODED:
            # request.data is populated in __parseRequestData
            requestargs = parse_qs(request.data, 1)

        #merge form args
        if requestargs:
            for arg, values in requestargs.items():
                # maintain first instance of an argument always
                safeDictUpdate(allargs,arg,values[0])
        elif hasattr(request,'json'):
            # if JSON, parse root elements instead of form elements
            self.__mergeDocument(allargs, request.json)
        elif hasattr(request,'yaml'):
            # if YAML, parse root elements instead of form elements
            self.__mergeDocument(allargs, request.yaml)
        elif hasattr(request,'xml'):
            # if XML, parse attributes first, then root nodes
            for key in request.xml.attrib:
                safeDictUpdate(allargs, key, request.xml.attrib[key])
            for el in request.xml.findall("*"):
                safeDictUpdate(allargs, el.tag,el.text)

    def __filterRequests(self,request):
        """Filters incoming requests"""
        for webFilter in self.__requestFilters:
            webFilter.filterRequest(request)

    def __filterResponses(self,request,response):
        """Filters outgoing responses"""
        for webFilter in self.__responseFilters:
            webFilter.filterResponse(request,response)