diff --git a/t/test.py b/t/test.py index 6c45391..6fa8b52 100644 --- a/t/test.py +++ b/t/test.py @@ -437,9 +437,8 @@ def testServerInsecurePathRelative(self): rrq.options = {} # Start the download. - self.assertRaises( - tftpy.TftpException, serverstate.start, rrq.encode().buffer - ) + with self.assertRaisesRegex(tftpy.TftpException, 'bad file path'): + serverstate.start(rrq.encode().buffer) def testServerInsecurePathRootSibling(self): raddress = "127.0.0.2" @@ -456,9 +455,8 @@ def testServerInsecurePathRootSibling(self): rrq.options = {} # Start the download. - self.assertRaises( - tftpy.TftpException, serverstate.start, rrq.encode().buffer - ) + with self.assertRaisesRegex(tftpy.TftpException, 'bad file path'): + serverstate.start(rrq.encode().buffer) def testServerSecurePathAbsolute(self): raddress = "127.0.0.2" @@ -502,6 +500,27 @@ def testServerSecurePathRelative(self): isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) ) + def testServerPathRoot(self): + raddress = "127.0.0.2" + rport = 10000 + timeout = 5 + with self.dummyServerDir() as d: + root = '/' + serverstate = tftpy.TftpContexts.TftpContextServer( + raddress, rport, timeout, root + ) + rrq = tftpy.TftpPacketTypes.TftpPacketRRQ() + rrq.filename = os.path.join(os.path.abspath(d), "foo", "bar") + rrq.mode = "octet" + rrq.options = {} + + # Start the download. + serverstate.start(rrq.encode().buffer) + # Should be in expectack state. + self.assertTrue( + isinstance(serverstate.state, tftpy.TftpStates.TftpStateExpectACK) + ) + def testServerDownloadWithStopNow(self, output="/tmp/out"): log.debug("===> Running testcase testServerDownloadWithStopNow") root = os.path.dirname(os.path.abspath(__file__)) diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index dd217b6..cc7ee91 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -295,8 +295,11 @@ def serverInitial(self, pkt, raddress, rport): # (e.g. '..') and ensure that is still within the server's # root directory self.full_path = os.path.abspath(full_path) - log.debug("full_path is %s", full_path) - if self.full_path.startswith(os.path.normpath(self.context.root) + os.sep): + # Determine root path, replace double slashes with single + # // => / + root_path = (os.path.normpath(self.context.root) + os.sep).replace('//', '/') + log.debug("full_path is %s, server root is %s" % (self.full_path, root_path)) + if self.full_path.startswith(root_path): log.info("requested file is in the server root - good") else: log.warning("requested file is not within the server root - bad")