#!/usr/bin/python3

from setproctitle import setproctitle
from shell_cmd import sh
from nfstream import NFStreamer
import time
import sys, signal
import json

# Default timeout value, kept as a string like the "timeout" entries of the
# rules file (presumably seconds -- confirm against the ipset usage).
deftimeout = "3600"

# Set the process title of this daemon to "dpi".
setproctitle("dpi")

# Flow source iterated by NexDPI.main: an NFStreamer reading from "br0".
online_streamer = NFStreamer(
   source="br0",
   promiscuous_mode=False,
   splt_analysis=20,
   statistical_analysis=False,
)
 
# Example rules file, printed when /etc/nexdpi/dpirules.json cannot be loaded.
# "Cats" is keyed by flow category name, "Apps" by application name.
# "nostart" and "knowstarts" are lists of name prefixes: the consumer turns
# them into a tuple for str.startswith, so a bare string would be split into
# single characters ("TLS" would match every name starting with T, L or S).
templconf = """
{
  "Cats":{

     "Network":{
        "nostart": [],
        "noapps": [],
        "knownapps": ["DHCP", "DHCPV6", "DNScrypt", "DoH_DoT", "DNS", "Ookla", "ICMP", "ICMPV6", "IGMP", "LLMNR", "MDNS", "DoH_DoT", "Ookla", "WSD"],
        "ipset": "system_triplet",
        "timeout":"3600"
     },

     "Game":{
        "nostart": [],
        "noapps": [],
        "knownapps": ["Steam", "Xbox", "Playstation"],
        "ipset": "streaming_triplet",
        "timeout":"3600"
     },

     "Template-GroupName":{
        "nostart": ["DNS", "ICMP"],
        "noapps": [],
        "knownapps": [],
        "ipset": "name_ipset",
        "timeout":"3600"
     }
  },

  "Apps":{
     "TikTok":{
        "ipset": "kids_triplet",
        "nostart": [],
        "timeout": "3600",
        "knowstarts": ["TLS"]
     }
  }
}
"""

# Load the rules once at startup.  If the file cannot be read or is not valid
# JSON, explain why, print an example configuration and stop.
try:
   # "with" closes the file even when JSON parsing fails.
   with open("/etc/nexdpi/dpirules.json", "r") as fconf:
      R = json.load(fconf)
except (OSError, ValueError) as err:
   # OSError: missing/unreadable file.  ValueError: invalid JSON or encoding
   # (json.JSONDecodeError and UnicodeDecodeError are both ValueErrors).
   print(" Cant start without a rules file.\n\n")
   print(" Reason: " + str(err) + "\n\n")
   print(" Please use the following example to create a JSON config file as /etc/nexdpi/dpirules.json\n\n")
   print(templconf)
   # NOTE(review): exit status 0 is kept from the original so that existing
   # service supervision does not change; consider a non-zero status.
   sys.exit(0)


# Rule tables consulted by NexDPI.main: Cats is looked up by the flow's
# category name, Apps by the last dot-separated part of its application name.
Cats = R['Cats']
Apps = R['Apps']


def reloadconf(signum, frame, path="/etc/nexdpi/dpirules.json"):
   """Signal handler: re-read the rules file and swap in the new rules.

   The new rules replace the module-level ``Cats`` and ``Apps`` only after
   the whole file has been read and both sections were found, so a broken
   file never leaves a half-applied configuration behind.  On any error the
   current rules stay active and a message is printed.

   Args:
      signum: Signal number passed by the ``signal`` module (unused).
      frame: Interrupted stack frame passed by the ``signal`` module (unused).
      path: Rules file to load; defaults to the daemon's standard location.
   """
   global Cats
   global Apps

   try:
      # "with" closes the file even when parsing fails.
      with open(path, "r") as fconf:
         rules = json.load(fconf)
      new_cats = rules['Cats']
      new_apps = rules['Apps']
      if not isinstance(new_cats, dict) or not isinstance(new_apps, dict):
         raise TypeError('"Cats" and "Apps" must be JSON objects')
   except Exception as err:
      # Best effort: an exception escaping a signal handler would propagate
      # into the interrupted main loop, so report it and keep the old rules.
      print("Error loading rules file.", err)
      return

   Cats = new_cats
   Apps = new_apps
   print("Rules file reloaded")



# Re-read the rules file whenever the process receives SIGHUP.
# Cats and Apps are already bound right after the startup load; they are not
# re-assigned from the startup data here, because doing so after the handler
# is installed could overwrite rules that a SIGHUP has just reloaded.
signal.signal(signal.SIGHUP, reloadconf)

# "<application name> <category name>" labels that NexDPI.main has already
# printed as unrecognised, so that each one is reported only once.
UnknownMatch=[]

class NexDPI():
   """Add live flows to ipsets according to the loaded rules.

   ``main`` iterates ``online_streamer``.  For every flow whose category has
   an entry in ``Cats``, or whose application has an entry in ``Apps``, the
   flow's ``dst_ip,dst_port,src_ip`` triplet is added to the rule's ipset.
   Flows that did not end up marked as known are printed once (recorded in
   ``UnknownMatch``) when the next flow arrives.
   """

   # "<application name> <category name>" of the most recent flow; False
   # until the first flow has been seen.
   fullname=False
   # True when the most recent flow matched a rule's "knownapps" /
   # "knowstarts" entry.
   isknown=False

   def __init__(self):
      print(time.asctime(), "NexDPI created")

   @staticmethod
   def _as_tuple(value):
      """Return a rule field as a tuple of strings.

      A bare string counts as ONE entry.  Passing it through ``tuple()``
      would split it into characters, so ``"knowstarts": "TLS"`` would match
      every name starting with T, L or S.  An empty string yields no entries.
      """
      if isinstance(value, str):
         return (value,) if value else ()
      return tuple(value)

   @staticmethod
   def _ipset_name(base, ip_version):
      """Return the ipset name for a flow: IPv6 flows use ``base + "6"``."""
      if ip_version == 6:
         return base + "6"
      return base

   @staticmethod
   def _ipset_command(ipset_name, triplet, timeout):
      """Build the shell line that adds *triplet* unless it is already set."""
      target = ipset_name + " " + triplet
      return ("ipset test " + target + " >/dev/null 2>&1 || ipset add "
              + target + " timeout " + str(timeout) + " > /dev/null 2>&1")

   def _add_to_ipset(self, rule, ip_version, triplet):
      """Run the ipset command for *rule*; a missing timeout uses deftimeout."""
      # NOTE(review): the command is a shell line built by concatenation and
      # the ipset name and timeout come from the rules file, so that file
      # must be trusted.
      name = self._ipset_name(rule['ipset'], ip_version)
      sh(self._ipset_command(name, triplet, rule.get('timeout', deftimeout)))

   def _report_unknown(self):
      """Print the previous flow's label once if it was not marked known."""
      if self.fullname and not self.isknown and self.fullname not in UnknownMatch:
         print(self.fullname)
         UnknownMatch.append(self.fullname)

   def main(self):
      """Process flows from ``online_streamer`` until the stream ends."""
      print(time.asctime(), "NexDPI started")
      for flow in online_streamer:
         # The verdict on the previous flow is only final here.
         self._report_unknown()

         # Work on one snapshot per flow: the SIGHUP handler may rebind the
         # module-level Cats/Apps at any point during the iteration.
         cats, apps = Cats, Apps

         aname = flow.application_name
         cname = flow.application_category_name
         self.isknown = False
         self.fullname = aname + " " + cname
         triplet = str(flow.dst_ip) + "," + str(flow.dst_port) + "," + str(flow.src_ip)
         # Last dot-separated component, e.g. "TLS.TikTok" -> "TikTok".
         sername = aname.rsplit(".", 1)[-1]
         ipv = flow.ip_version

         # Category rule first; when it excludes this flow, fall through to
         # the per-application rules.
         if cname in cats:
            rule = cats[cname]
            excluded = (
               aname.startswith(self._as_tuple(rule.get('nostart', ())))
               or sername in self._as_tuple(rule.get('noapps', ()))
            )
            if not excluded:
               self._add_to_ipset(rule, ipv, triplet)
               if sername in self._as_tuple(rule.get('knownapps', ())):
                  self.isknown = True
               continue

         if sername in apps:
            rule = apps[sername]
            if not aname.startswith(self._as_tuple(rule.get('nostart', ()))):
               self._add_to_ipset(rule, ipv, triplet)
               if aname.startswith(self._as_tuple(rule.get('knowstarts', ()))):
                  self.isknown = True



# Script entry point: create the daemon object and process flows.
# (No local "import sys" here: sys is already imported at the top of the file.)
if __name__ == "__main__":
   ndpi=NexDPI()
   ndpi.main()
   
