#!/usr/bin/env python3

from setproctitle import setproctitle
from dnslib import *
from dnslib import server
import socket
import netifaces as nf
import struct



# Top-level domains whose queries are forwarded to the local nameserver.
localTLDs = ['arpa', 'thc']

# Domains (last two labels of the query name) that are also forwarded to the
# local nameserver.
localDomains = ['nexlab.net', 'nexlab.it']

# Local nameserver that answers the TLDs and domains listed above.
localNShost, localNSport = '127.0.0.1', 53

# Address the proxy servers listen on; adjust to your setup (for example
# 0.0.0.0 to listen on every address).
proxy_addr = '192.168.42.1'

# One proxy server is started per entry: listening port -> name of the network
# interface whose address is used as the source of queries sent to the
# external DNS server.
proxy_ports = {
    5301: 'tun10',
    5302: 'tun11',
}

# External DNS server used for every name that is not handled locally.
external_dns_server_addr, external_dns_server_port = '8.8.8.8', 53


# Set this process's title to "nsproxy" (via the setproctitle package).
setproctitle("nsproxy")



class NSRecord(DNSRecord):
    """DNSRecord whose send() can bind the outgoing socket to a local address.

    Construction is inherited unchanged from DNSRecord; only question() and
    send() are overridden.
    """

    @classmethod
    def question(cls, qname, qtype="A", qclass="IN"):
        """
            Shortcut to create question
            >>> q = NSRecord.question("www.google.com")
            >>> print(q)
            ;; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: ...
            ;; flags: rd; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 0
            ;; QUESTION SECTION:
            ;www.google.com.                IN      A
            >>> q = NSRecord.question("www.google.com","NS")
            >>> print(q)
            ;; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: ...
            ;; flags: rd; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 0
            ;; QUESTION SECTION:
            ;www.google.com.                IN      NS
        """
        # Build through ``cls`` so the result is an NSRecord (or a subclass)
        # and therefore has the source-binding send() below.
        return cls(q=DNSQuestion(qname, getattr(QTYPE, qtype),
                                 getattr(CLASS, qclass)))

    @staticmethod
    def _recv_exact(sock, count):
        """Read exactly ``count`` bytes from a stream socket.

        Raises ConnectionError if the peer closes the connection first, so a
        truncated reply cannot turn into an endless read loop.
        """
        chunks = []
        remaining = count
        while remaining:
            chunk = sock.recv(remaining)
            if not chunk:
                raise ConnectionError(
                    "Connection closed with %d of %d bytes still expected"
                    % (remaining, count))
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def send(self, dest, port=53, tcp=False, timeout=None, ipv6=False,
             from_ip=False, from_port=False):
        """
            Send packet to nameserver and return response

            dest, port  -- nameserver address and port
            tcp         -- use TCP (2-byte length-prefixed framing) instead
                           of UDP
            timeout     -- socket timeout in seconds; None leaves the socket
                           blocking
            ipv6        -- use an AF_INET6 socket instead of AF_INET
            from_ip     -- local address to bind before sending (falsy: any)
            from_port   -- local port to bind before sending (falsy: any)

            Returns the raw response packet as bytes (not parsed).

            Raises ValueError if the packet is too long for TCP framing,
            ConnectionError if a TCP peer closes before the full response
            arrived, and whatever OSError the socket calls raise (including
            timeouts).
        """
        data = self.pack()
        family = socket.AF_INET6 if ipv6 else socket.AF_INET

        # Bind only when the caller asked for a specific source; the part that
        # was not given falls back to "any address" / "any port".
        srcaddr = None
        if from_ip or from_port:
            srcaddr = (from_ip or '', from_port or 0)

        if tcp:
            if len(data) > 65535:
                raise ValueError("Packet length too long: %d" % len(data))
            with socket.socket(family, socket.SOCK_STREAM) as sock:
                if srcaddr:
                    sock.bind(srcaddr)
                if timeout is not None:
                    sock.settimeout(timeout)
                sock.connect((dest, port))
                sock.sendall(struct.pack("!H", len(data)) + data)
                # The reply is framed the same way: a 2-byte big-endian
                # length followed by exactly that many bytes.
                (length,) = struct.unpack("!H", self._recv_exact(sock, 2))
                return self._recv_exact(sock, length)

        with socket.socket(family, socket.SOCK_DGRAM) as sock:
            if srcaddr:
                sock.bind(srcaddr)
            if timeout is not None:
                sock.settimeout(timeout)
            sock.sendto(data, (dest, port))
            response, _ = sock.recvfrom(8192)
        return response


class RouteResolver:
    """Resolver that routes each query to a local or an external nameserver.

    Queries whose TLD is listed in ``tlds``, or whose last two labels are
    listed in ``domains``, are forwarded to ``localns`` (a ``(host, port)``
    tuple). Every other query is forwarded to the module-level external DNS
    server, sent from the IPv4 address of ``iface`` when it can be determined.

    ``tlds``, ``domains`` and ``localns`` are assigned by the caller after
    construction.
    """

    # Name of the interface whose first IPv4 address is used as the source
    # address of external queries; when unset, no source address is bound.
    iface = None

    # Seconds to wait for an upstream answer. Without a limit a lost packet
    # blocks the request forever. Set to None to wait indefinitely.
    timeout = 5.0

    # DNS RCODE 2, "Server failure" (RFC 1035 section 4.1.1).
    _SERVFAIL = 2

    def _is_local(self, q_name):
        """Return True when ``q_name`` must be answered by the local nameserver.

        Matching ignores case, because DNS names are case-insensitive.
        """
        labels = q_name.rstrip(".").lower().split(".")
        tld = labels[-1]  # empty string for the root name "."
        if tld and tld in (t.lower() for t in self.tlds):
            return True
        if len(labels) < 2:
            return False
        domain = ".".join(labels[-2:])
        return domain in (d.lower() for d in self.domains)

    def _source_ip(self):
        """Return the first IPv4 address of ``self.iface``, or False.

        Best effort: an unset, unknown or address-less interface yields False,
        which makes send() skip binding a source address.
        """
        if not self.iface:
            return False
        try:
            return nf.ifaddresses(self.iface)[nf.AF_INET][0]['addr']
        except (KeyError, IndexError, ValueError, OSError):
            return False

    def resolve(self, request, handler):
        """Answer ``request`` by forwarding its question to the right upstream.

        request -- the parsed query; its question is forwarded unchanged, so
                   the upstream sees the client's real record type and class
        handler -- unused here; accepted because the server passes it

        Returns the reply record built from ``request.reply()``, carrying the
        upstream's answer records and response code, or SERVFAIL when the
        upstream could not be reached.
        """
        reply = request.reply()
        q = request.get_q()

        if self._is_local(str(q.qname)):
            host, port = self.localns
            srcip = False
        else:
            host, port = external_dns_server_addr, external_dns_server_port
            srcip = self._source_ip()

        # Reuse the client's own question so qtype and qclass are preserved
        # (AAAA, MX, PTR, ...) instead of always asking for an A record.
        query = NSRecord(q=q)
        try:
            raw = query.send(host, port, timeout=self.timeout, from_ip=srcip)
        except OSError:
            # Timeout, unreachable network, failed bind, ...: tell the client
            # the lookup failed rather than leaving it without any reply.
            reply.header.rcode = self._SERVFAIL
            return reply

        answer = NSRecord.parse(raw)
        # Pass on the upstream status so e.g. NXDOMAIN is not turned into an
        # empty NOERROR response.
        reply.header.rcode = answer.header.rcode
        for rr in answer.rr:
            reply.add_answer(rr)
        return reply

# Needed only by the idle loop at the bottom of this script section.
import time

# Start one DNS server per configured port; each gets its own resolver bound
# to the interface associated with that port.
pservers = {}
for pport, iface in proxy_ports.items():
    resolver = RouteResolver()
    resolver.iface = iface
    resolver.domains = localDomains
    resolver.tlds = localTLDs
    resolver.localns = (localNShost, localNSport)
    pservers[pport] = server.DNSServer(resolver, port=pport, address=proxy_addr)
    pservers[pport].start_thread()

# Keep the main thread alive while the servers run. Sleeping instead of
# spinning on ``pass`` avoids pinning a CPU core at 100%.
while True:
    time.sleep(1)
