diff --git a/examples/minimal/wirego_minimal.go b/examples/minimal/wirego_minimal.go index e711554..5aae760 100644 --- a/examples/minimal/wirego_minimal.go +++ b/examples/minimal/wirego_minimal.go @@ -72,6 +72,7 @@ func (WiregoMinimalExample) GetDetectionHeuristicsParents() []string { return []string{"udp", "http"} } +// DetectionHeuristic applies an heuristic to identify the protocol. func (WiregoMinimalExample) DetectionHeuristic(packetNumber int, src string, dst string, layer string, packet []byte) bool { //All packets starting with 0x00 should be passed to our dissector (super advanced heuristic) if len(packet) != 0 && packet[0] == 0x00 { diff --git a/wirego_remote/python/wirego.py b/wirego_remote/python/wirego.py index a8549bb..cef3fa0 100644 --- a/wirego_remote/python/wirego.py +++ b/wirego_remote/python/wirego.py @@ -20,7 +20,7 @@ # FieldId type (just overloading int type) FieldId = NewType('FieldId', int) -#ValueType defines a type of data supported by Wireshark +#ValueType defines a field data type supported by Wireshark class ValueType(IntEnum): ValueTypeNone = 0x01 ValueTypeBool = 0x02 @@ -33,60 +33,63 @@ class ValueType(IntEnum): ValueTypeCString = 0x09 ValueTypeString = 0x10 -#DisplayMode tells Wireshark how to display a field +# DisplayMode tells Wireshark how to display a field class DisplayMode(IntEnum): DisplayModeNone = 0x01 DisplayModeDecimal = 0x02 DisplayModeHexadecimal = 0x03 -#DetectionFilterType defines the type of a declared detection filter +# DetectionFilterType defines the type of a declared detection filter class DetectionFilterType(IntEnum): DetectionFilterTypeInt = 0x01 DetectionFilterTypeStr = 0x02 +# WiregoField holds the description of a field @dataclass class WiregoField: - wirego_field_id: FieldId - name: str - filter: str - value_type: ValueType - display_mode: DisplayMode - + wirego_field_id: FieldId # A user defined unique value identfying this field (enum) + name: str # Field name + filter: str # Field filter + value_type: ValueType # Type of data + display_mode: DisplayMode # Display mode in Wireshark +# DetectionFilter defines a detection filter (ex: tcp.port = 12) @dataclass class DetectionFilter: - filter_type: DetectionFilterType - name: str - value_int: int - value_str: str + filter_type: DetectionFilterType # Type of filter + name: str # Filter name (ex: "tcp.port") + value_int: int # Filter value as int (ex: 12) + value_str: str # or filter value as str (ex: "192.168.1.1"), depending on filter_type +# DissectField holds a dissection result field (refers to a WiregoField and specifies offset+length) @dataclass class DissectField: - wirego_field_id: FieldId - offset: int - length: int - sub_fields: List['DissectField'] + wirego_field_id: FieldId # Field id (Wirego) + offset: int # Field Offset in packet + length: int # Field length + sub_fields: List['DissectField'] # Sub fields +# DissectResult holds a dissection result for a given packet @dataclass class DissectResult: - protocol: str - info: str - fields: List[DissectField] + protocol: str # Protocol column in Wireshark + info: str # Info column in Wireshark + fields: List[DissectField] # List of fields -# Stores a given field from a dissection result +# DissectResultFieldFlatten stores a given field from a dissection result @dataclass class DissectResultFieldFlatten: - parent_idx: int #Index of parent field (for nested fields) - wirego_field_id: FieldId # Field id (Wirego) - offset: int # Field Offset in packet - length: int # Field length + parent_idx: int # Index of parent field (for nested fields) + wirego_field_id: FieldId # Field id (Wirego) + offset: int # Field Offset in packet + length: int # Field length @dataclass class DissectResultFlattenEntry: - protocol: str # Protocol column for Wireshark - info: str # Info column for Wireshark + protocol: str # Protocol column for Wireshark + info: str # Info column for Wireshark fields: List[DissectResultFieldFlatten] # List of fields for Wireshark @@ -127,25 +130,31 @@ def __init__(self, zmq_endpoint: str, verbose: bool, wglistener: WiregoListener) self.zmq_endpoint = zmq_endpoint self.verbose = verbose self.wglistener = wglistener - self.cache_enable = False + self.cache_enable = False # Cache is disabled by default self.cache = {} if self.verbose: logging.basicConfig(level=logging.DEBUG) else: logging.basicConfig(level=logging.WARNING) + # results_cache_enable enables dissection results cache (packets will be dissected only once) def results_cache_enable(self, enable: bool): self.cache_enable = enable + # listen waits for Wirego bridge commands and loop def listen(self): logging.warning("Waiting for Wirego bridge commands...") + + # Locally store some implementation structures self.fields = self.wglistener.get_fields() self.heuristics_parents = self.wglistener.get_detection_heuristics_parents() self.detection_filters = self.wglistener.get_detection_filters() + # Setup ZMQ context = zmq.Context() - socket = context.socket(zmq.REP) + socket = context.socket(zmq.REP) # We're using REQ/REP scheme socket.bind(self.zmq_endpoint) + while True: # Wait for next request from client messageFrames = socket.recv_multipart(0, False, False) @@ -154,8 +163,10 @@ def listen(self): return logging.debug("Received request: %s" % messageFrames) + # First frame contains the command name msg_type = messageFrames[0].bytes.decode('utf-8') logging.debug("-> Message type: "+msg_type) + match msg_type: case "utility_ping\x00": logging.warning("Received ping request from Wirego Bridge.") @@ -195,27 +206,37 @@ def listen(self): socket.send(b"\x00") return + # _utility_get_version returns the Wirego ZMQ API version def _utility_get_version(self, socket, messageFrames): socket.send(b"\x01", zmq.SNDMORE) socket.send(b"\x02", zmq.SNDMORE) socket.send(b"\x00") + # _setup_get_plugin_name returns the plugin name + def _setup_get_plugin_name(self, socket, messageFrames): + socket.send(b"\x01", zmq.SNDMORE) + socket.send(self.wglistener.get_name().encode() + b'\x00') + + # _setup_get_plugin_filter returns the plugin filter def _setup_get_plugin_filter(self, socket, messageFrames): socket.send(b"\x01", zmq.SNDMORE) socket.send(self.wglistener.get_filter().encode() + b'\x00') - def _setup_get_plugin_name(self, socket, messageFrames): + # _setup_get_fields_count returns the number of defined custom fields + def _setup_get_fields_count(self, socket, messageFrames): socket.send(b"\x01", zmq.SNDMORE) - socket.send(self.wglistener.get_name().encode() + b'\x00') + socket.send(len(self.fields).to_bytes(4, 'little')) + # _setup_get_field returns a filter description by index def _setup_get_field(self, socket, messageFrames): if len(messageFrames) != 2: socket.send(b"\x00") return - idx = int.from_bytes(messageFrames[1], 'little') + idx = int.from_bytes(messageFrames[1], 'little')# Frame 1 contains index if idx >= len(self.fields): socket.send(b"\x00") return + # Returns field description socket.send(b"\x01", zmq.SNDMORE) socket.send(self.fields[idx].wirego_field_id.to_bytes(4, 'little'), zmq.SNDMORE) socket.send(self.fields[idx].name.encode() + b'\x00', zmq.SNDMORE) @@ -223,30 +244,16 @@ def _setup_get_field(self, socket, messageFrames): socket.send(self.fields[idx].value_type.to_bytes(4, 'little'), zmq.SNDMORE) socket.send(self.fields[idx].display_mode.to_bytes(4, 'little')) - def _setup_detect_heuristic_parent(self, socket, messageFrames): - if len(messageFrames) != 2: - socket.send(b"\x00") - return - idx = int.from_bytes(messageFrames[1], 'little') - if idx >= len(self.heuristics_parents): - socket.send(b"\x00") - return - socket.send(b"\x01", zmq.SNDMORE) - socket.send(self.heuristics_parents[idx].encode() + b'\x00') - - - def _setup_get_fields_count(self, socket, messageFrames): - socket.send(b"\x01", zmq.SNDMORE) - socket.send(len(self.fields).to_bytes(4, 'little')) - + # _setup_detect_string returns the plugin detection as a string, by index def _setup_detect_string(self, socket, messageFrames): if len(messageFrames) != 2: socket.send(b"\x00") return - idx = int.from_bytes(messageFrames[1], 'little') + idx = int.from_bytes(messageFrames[1], 'little') # Frame 1 contains index if idx >= len(self.fields): socket.send(b"\x00") return + # Iterate over all detection filters and look for strings cnt = 0 for f in self.detection_filters: if f.filter_type == DetectionFilterType.DetectionFilterTypeStr: @@ -260,14 +267,16 @@ def _setup_detect_string(self, socket, messageFrames): #gone too far, no more strings socket.send(b"\x00") + # _setup_detect_int returns the plugin detection as an int, by index def _setup_detect_int(self, socket, messageFrames): if len(messageFrames) != 2: socket.send(b"\x00") return - idx = int.from_bytes(messageFrames[1], 'little') + idx = int.from_bytes(messageFrames[1], 'little') # Frame 1 contains index if idx >= len(self.fields): socket.send(b"\x00") return + # Iterate over all detection filters and look for integers cnt = 0 for f in self.detection_filters: if f.filter_type == DetectionFilterType.DetectionFilterTypeInt: @@ -281,15 +290,32 @@ def _setup_detect_int(self, socket, messageFrames): #gone too far, no more int socket.send(b"\x00") + # _setup_detect_heuristic_parent returns heuristic detection parents by index + def _setup_detect_heuristic_parent(self, socket, messageFrames): + if len(messageFrames) != 2: + socket.send(b"\x00") + return + idx = int.from_bytes(messageFrames[1], 'little') # Frame 1 contains index + if idx >= len(self.heuristics_parents): + socket.send(b"\x00") + return + socket.send(b"\x01", zmq.SNDMORE) + socket.send(self.heuristics_parents[idx].encode() + b'\x00') + + + + # _process_heuristic runs the user defined detection heuristic on a given packet def _process_heuristic(self, socket, messageFrames): if len(messageFrames) != 6: socket.send(b"\x00") return + # Extract packet information from Frames packet_number = messageFrames[1] src = messageFrames[2] dst = messageFrames[3] layer = messageFrames[4] packet_data = messageFrames[5] + # Call the user defined heuristic result = self.wglistener.detection_heuristic(packet_number, src, dst, layer, packet_data) socket.send(b"\x01", zmq.SNDMORE) if result: @@ -297,52 +323,60 @@ def _process_heuristic(self, socket, messageFrames): else: socket.send(b"\x00") + # _process_dissect_packet calls the packet dissection and return a disseciton handler def _process_dissect_packet(self, socket, messageFrames): if len(messageFrames) != 6: socket.send(b"\x00") return + # Extract packet information from Frames pktnum = int.from_bytes(messageFrames[1], 'little') src = messageFrames[2].bytes.decode('utf-8') dst = messageFrames[3].bytes.decode('utf-8') layer = messageFrames[4].bytes.decode('utf-8') packet_data = messageFrames[5] - # Not in cache, dissect packet + # Not in cache, dissect packet. if not pktnum in self.cache: result = self.wglistener.dissect_packet(pktnum, src, dst, layer, packet_data) + # Add dissection result to cache so that we can access it later with result accessors self._add_result_to_cache(result, pktnum) socket.send(b"\x01", zmq.SNDMORE) socket.send(pktnum.to_bytes(4, 'little')) # use pkt number as dissect handler + # _result_get_protocol returns the protocol string for a given dissection handler def _result_get_protocol(self, socket, messageFrames): if len(messageFrames) != 2: socket.send(b"\x00") return - packet_number = int.from_bytes(messageFrames[1], 'little') + packet_number = int.from_bytes(messageFrames[1], 'little') # Frame 1 is the dissection handler + # Retrieve dissection result from cache (index by packet_number) if not packet_number in self.cache: socket.send(b"\x00") return socket.send(b"\x01", zmq.SNDMORE) socket.send(self.cache[packet_number].protocol.encode() + b'\x00') + # _result_get_info returns the info string for a given dissection handler def _result_get_info(self, socket, messageFrames): if len(messageFrames) != 2: socket.send(b"\x00") return - packet_number = int.from_bytes(messageFrames[1], 'little') + packet_number = int.from_bytes(messageFrames[1], 'little') # Frame 1 is the dissection handler + # Retrieve dissection result from cache (index by packet_number) if not packet_number in self.cache: socket.send(b"\x00") return socket.send(b"\x01", zmq.SNDMORE) socket.send(self.cache[packet_number].info.encode() + b'\x00') - + # _result_get_fields_count returns the number of extracted fields for a given dissection handler def _result_get_fields_count(self, socket, messageFrames): if len(messageFrames) != 2: socket.send(b"\x00") return - packet_number = int.from_bytes(messageFrames[1], 'little') + packet_number = int.from_bytes(messageFrames[1], 'little') # Frame 1 is the dissection handler + # Retrieve dissection result from cache (index by packet_number) if not packet_number in self.cache: socket.send(b"\x00") return @@ -350,13 +384,15 @@ def _result_get_fields_count(self, socket, messageFrames): socket.send(b"\x01", zmq.SNDMORE) socket.send(int.to_bytes(count, 4, 'little')) + # _result_get_field returns an extracted fields for a given dissection handler and an index def _result_get_field(self, socket, messageFrames): if len(messageFrames) != 3: socket.send(b"\x00") return - packet_number = int.from_bytes(messageFrames[1], 'little') - idx = int.from_bytes(messageFrames[2], 'little') + packet_number = int.from_bytes(messageFrames[1], 'little') # Frame 1 is the dissection handler + idx = int.from_bytes(messageFrames[2], 'little') # Frame 2 is the field index + # Retrieve dissection result from cache (index by packet_number) if not packet_number in self.cache: socket.send(b"\x00") return @@ -364,19 +400,21 @@ def _result_get_field(self, socket, messageFrames): if idx >= len(result.fields): socket.send(b"\x00") return + # Send custom field contents socket.send(b"\x01", zmq.SNDMORE) socket.send(result.fields[idx].parent_idx.to_bytes(4, byteorder='little', signed=True), zmq.SNDMORE) socket.send(result.fields[idx].wirego_field_id.to_bytes(4, 'little'), zmq.SNDMORE) socket.send(result.fields[idx].offset.to_bytes(4, 'little'), zmq.SNDMORE) socket.send(result.fields[idx].length.to_bytes(4, 'little')) - + # _result_release releases a packet dissection result def _result_release(self, socket, messageFrames): if len(messageFrames) != 2: socket.send(b"\x00") return - packet_number = int.from_bytes(messageFrames[1], 'little') + packet_number = int.from_bytes(messageFrames[1], 'little') # Frame 1 is the dissection handler + # If cache is disabled, remove if so that we may process it again (eventually with more context) if not self.cache_enable: # Remove from cache self.cache.pop(packet_number) @@ -384,6 +422,7 @@ def _result_release(self, socket, messageFrames): socket.send(b"\x01") return + # _add_result_to_cache add a dissection result to the cache def _add_result_to_cache(self, result, pktnum): # Flatten results to a simple list with parenIdx pointing to parent's entry flatten = DissectResultFlattenEntry(result.info, result.info, []) @@ -391,6 +430,7 @@ def _add_result_to_cache(self, result, pktnum): self._add_fields_recursive(flatten, -1, r) self.cache[pktnum] = flatten # Since we have one result per packet number, use pktnum as key + # _add_fields_recursive recursively adds fields and subfields to cache and compute parent_idx def _add_fields_recursive(self, flatten: DissectResultFlattenEntry, parent_idx: int, field: DissectField): new_parent_idx: int field_flatten = DissectResultFieldFlatten(parent_idx, field.wirego_field_id, field.offset, field.length) diff --git a/wirego_remote/python/wirego_minimal.py b/wirego_remote/python/wirego_minimal.py index 4e4108d..9c9bfa6 100644 --- a/wirego_remote/python/wirego_minimal.py +++ b/wirego_remote/python/wirego_minimal.py @@ -43,6 +43,7 @@ def get_detection_heuristics_parents(self): "http", ] + # detection_heuristic applies an heuristic to identify the protocol. def detection_heuristic(self, packet_number: int, src: str, dst: str, stack: str, packet: bytes) -> bool: #All packets starting with 0x00 should be passed to our dissector (super advanced heuristic) if (len(packet) != 0) and (packet[0] == 0x00):