diff --git a/ipi/interfaces/sockets.py b/ipi/interfaces/sockets.py index 112396a9..7e52e0b6 100644 --- a/ipi/interfaces/sockets.py +++ b/ipi/interfaces/sockets.py @@ -91,6 +91,38 @@ class Status(object): Timeout = 32 +def get_socket(address="localhost", port=31415, mode="unix", server=False): + """Create socket. + Parameters: + - server: Driver socket or Client socket? + """ + if mode == "inet": + _socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + params = (address, int(port)) + _info = "Created inet socket with address " + address + " and port number " + str(port) + elif mode == "unix": + _socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + params = "/tmp/ipi_" + address + _info = "Created unix socket with address " + address + else: + raise NameError("Interface mode " + mode + " is not implemented (should be unix/inet)") + + if server: + connect = _socket.bind + else: + connect = _socket.connect + + try: + connect(params) + info(_info) + except: + if mode == "unix": + raise ValueError("Error opening unix socket. Check if a file " + ("/tmp/ipi_" + address) + " exists, and remove it if unused.") + else: + raise + return _socket + + class DriverSocket(socket.socket): """Deals with communication between the client and driver code. @@ -697,3 +729,102 @@ def poll(self): self.poll_iter += 1 self.pool_distribute() + + +class Client(DriverSocket): + """Deals as starting point for implementing a clien in python. + + Deals with sending and receiving the data from the client code. + + Attributes: + havedata: Boolean giving whether the client calculated the forces. + """ + + def __init__(self, address="localhost", port=31415, mode="unix", _socket=True): + """Initialises Driver. + + Args: + - socket: If a socket should be opened. Can be False for testing purposes. + - address: A string giving the name of the host network. + - port: An integer giving the port the socket will be using. + - mode: A string giving the type of socket used. + """ + if _socket: + # open client socket + _socket = get_socket(address, port, mode, server=False) + else: + _socket = None + super(Client,self).__init__(socket=_socket) + self.MP_VERSION = np.int32(1) + self.extra = "" + self.havedata = False + self.needsinit = True + self.vir = np.zeros((3,3),np.float64) + self.cellh = np.zeros((3,3),np.float64) + self.cellih = np.zeros((3,3),np.float64) + self.nat = np.int32() + + # needs to be set from the client + # For now manually set the verbosity + self.verb = verbosity.High + + + def _getforce(self): + """Dummy _getforce routine. + + This function must be implemented by subclassing or providing a callback function. + This function is assumed to calculate the following: + - self._force: The force of the current positions at self._positions. + - self._potential: The potential of the current positions at self._positions. + """ + if hasattr(self, "callback"): + self._force, self._potential = self.callback(self._positions) + else: + raise NotImplementedError("_getforce must be implemented by providing a self.callback function or overwritten.") + + + def run(self): + """Serve forces until asked to finish. + + Serve force and potential, that are calculated in the user provided + routine _getforce. + """ + while 1: + msg = self.recv_msg() + if msg == "": + if self.verb > verbosity.Quiet: + print " @CLIENT: Shutting down." + break + elif msg == Message("status"): + if self.needsinit: + self.send_msg("needinit") + elif self.havedata: + self.send_msg("havedata") + else: + self.send_msg("ready") + elif msg == Message("init"): + self.bead = np.int32() + self.bead = self.recvall(self.bead) + str_pars = np.int32() + str_pars = self.recvall(str_pars) + self.pars = " "*str_pars + if str_pars > 0: + self.pars = self.recv(len(self.pars)) + self.needsinit = False + elif msg == Message("posdata"): + self.cellh = self.recvall(self.cellh) + self.cellih = self.recvall(self.cellih) + self.nat = self.recvall(self.nat) + self._positions = np.zeros((self.nat,3),np.float64) + self._positions = self.recvall(self._positions) + self._getforce() + self.havedata = True + elif msg == Message("getforce"): + self.sendall(Message("forceready")) + self.sendall(self._potential, 8) + self.sendall(self.nat, 4) + self.sendall(self._force, len(self._force)*8) + self.havedata=False + else: + print >>sys.stderr, " @CLIENT: Couldn't understand command:", msg + break