# -*- coding: utf-8 -*-
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC


class ElementIn(object):
    """Wait condition that performs a lookup inside a given parent element.

    Instances are callables meant for ``WebDriverWait.until``: each poll calls
    ``getattr(element, find_fn)(*locator)`` and returns its result, or
    ``False`` when that lookup raises a Selenium ``WebDriverException`` so the
    wait keeps polling.
    """

    # Name of the lookup method called on ``element``; subclasses override it.
    default_find_fn = 'find_element'

    def __init__(self, element, locator, find_fn=None):
        """
        :param element: object to search within (e.g. a WebElement).
        :param locator: positional arguments for the lookup method,
            e.g. ``(By.ID, 'foo')``.
        :param find_fn: name of the lookup method on ``element``; falls back
            to ``default_find_fn`` when ``None``.
        """
        self.element = element
        self.locator = locator
        self.find_fn = self.default_find_fn if find_fn is None else find_fn

    def __call__(self, driver):
        """Run one lookup attempt.

        ``driver`` is unused; it is accepted because ``WebDriverWait`` passes
        the driver to every condition it polls.
        """
        try:
            return getattr(self.element, self.find_fn)(*self.locator)
        except WebDriverException:
            # Element not there yet / went stale etc.: report "not ready" and
            # let WebDriverWait retry. Non-Selenium errors (e.g. a misspelled
            # find_fn) are no longer swallowed silently until the timeout.
            return False


class ElementsIn(ElementIn):
    """Like ``ElementIn`` but calls ``find_elements``.

    The returned list is the condition result; an empty list is falsy, so
    the wait keeps polling until at least one element matches.
    """

    default_find_fn = 'find_elements'


class SeleniumExtensions(object):
    """Collection of wait helpers and an XHR request logger for Selenium."""

    @classmethod
    def wait_elem(cls, driver, locator, inside_el=None, wait_time=10):
        """Wait up to ``wait_time`` seconds for one element and return it.

        :param driver: WebDriver used for the wait (and for the page-wide
            lookup when ``inside_el`` is not given).
        :param locator: ``(by, value)`` locator.
        :param inside_el: if truthy, search only inside this element.
        :param wait_time: timeout in seconds.
        """
        if inside_el:
            return WebDriverWait(driver, wait_time).until(
                ElementIn(inside_el, locator)
            )
        return WebDriverWait(driver, wait_time).until(
            EC.presence_of_element_located(locator)
        )

    @classmethod
    def wait_elems(cls, driver, locator, inside_el=None, wait_time=10):
        """Wait up to ``wait_time`` seconds for a non-empty list of elements.

        Parameters are the same as for ``wait_elem``.
        """
        if inside_el:
            return WebDriverWait(driver, wait_time).until(
                ElementsIn(inside_el, locator)
            )
        return WebDriverWait(driver, wait_time).until(
            EC.presence_of_all_elements_located(locator)
        )

    @classmethod
    def wait_for_js_load(cls, driver, wait_time=10):
        """Wait until ``document.readyState`` reports ``'complete'``."""
        WebDriverWait(driver, wait_time).until(
            lambda drv: drv.execute_script('return document.readyState') == 'complete'
        )

    @classmethod
    def add_requests_log(cls, driver):
        """Install an XMLHttpRequest logger in the current page.

        Every ``XMLHttpRequest.open`` call appends ``[method, url, xhr]`` to
        ``window._requestsLog``. A later ``send`` appends its body to the
        most recent entry for that xhr. Each call resets the log to empty;
        the prototype itself is patched only once per page, so calling this
        repeatedly does not log requests twice. Must be re-run after a
        navigation, since the patch lives on the page's ``window``.
        """
        js = '''
        (function() {
            // Start (or restart) with an empty log.
            window._requestsLog = [];

            // Return the most recent entry whose URL contains pathOrXHR
            // (when it is a string) or whose xhr object is pathOrXHR.
            window._findRequest = function(pathOrXHR) {
                for (var i = window._requestsLog.length - 1; i >= 0; i--) {
                    var request = window._requestsLog[i];
                    // String() guards against non-string URLs (e.g. URL objects).
                    if (typeof(pathOrXHR) == 'string' && String(request[1]).indexOf(pathOrXHR) > -1 ||
                            request[2] === pathOrXHR) {
                        return request;
                    }
                }
            };

            // Patch the prototype only once; wrapping the already-wrapped
            // methods again would record every request twice.
            if (window._requestsLogInstalled) {
                return;
            }
            window._requestsLogInstalled = true;

            var open = XMLHttpRequest.prototype.open;
            var send = XMLHttpRequest.prototype.send;

            XMLHttpRequest.prototype.open = function(method, url) {
                window._requestsLog.push([method, url, this]);
                return open.apply(this, arguments);
            };

            XMLHttpRequest.prototype.send = function(body) {
                var r = window._findRequest(this);
                if (r) {
                    r.push(body);
                }
                return send.apply(this, arguments);
            };
        })();
        '''
        return driver.execute_script(js)

    @classmethod
    def get_requests(cls, driver):
        """Return ``window._requestsLog`` as serialized by the driver.

        Entries are ``[method, url, xhr]`` plus the body once sent. How the
        xhr object itself is serialized depends on the driver.
        """
        return driver.execute_script('return window._requestsLog;')

    @classmethod
    def wait_for_request(cls, driver, path, wait_time=10):
        """Wait for the latest logged XHR whose URL contains ``path`` to finish.

        "Finished" means ``readyState == 4`` (DONE). Requires
        ``add_requests_log`` to have been called on the current page;
        otherwise the page script fails because ``window._findRequest`` is
        undefined (presumably surfacing as a WebDriverException).
        Returns that request's log entry.
        """
        WebDriverWait(driver, wait_time).until(
            lambda drv: drv.execute_script('''
                var r = window._findRequest(arguments[0]);
                return !!r && r[2].readyState == 4;
            ''', path) is True
        )
        return driver.execute_script(
            'return window._findRequest(arguments[0]);', path
        )