Skip to content
This repository was archived by the owner on Jan 31, 2024. It is now read-only.

Example of i-pi psi connection #240

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions ipi/interfaces/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,38 @@ class Status(object):
Timeout = 32


def get_socket(address="localhost", port=31415, mode="unix", server=False):
"""Create socket.
Parameters:
- server: Driver socket or Client socket?
"""
if mode == "inet":
_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
params = (address, int(port))
_info = "Created inet socket with address " + address + " and port number " + str(port)
elif mode == "unix":
_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
params = "/tmp/ipi_" + address
_info = "Created unix socket with address " + address
else:
raise NameError("Interface mode " + mode + " is not implemented (should be unix/inet)")

if server:
connect = _socket.bind
else:
connect = _socket.connect

try:
connect(params)
info(_info)
except:
if mode == "unix":
raise ValueError("Error opening unix socket. Check if a file " + ("/tmp/ipi_" + address) + " exists, and remove it if unused.")
else:
raise
return _socket


class DriverSocket(socket.socket):

"""Deals with communication between the client and driver code.
Expand Down Expand Up @@ -697,3 +729,102 @@ def poll(self):

self.poll_iter += 1
self.pool_distribute()


class Client(DriverSocket):
"""Deals as starting point for implementing a clien in python.

Deals with sending and receiving the data from the client code.

Attributes:
havedata: Boolean giving whether the client calculated the forces.
"""

def __init__(self, address="localhost", port=31415, mode="unix", _socket=True):
"""Initialises Driver.

Args:
- socket: If a socket should be opened. Can be False for testing purposes.
- address: A string giving the name of the host network.
- port: An integer giving the port the socket will be using.
- mode: A string giving the type of socket used.
"""
if _socket:
# open client socket
_socket = get_socket(address, port, mode, server=False)
else:
_socket = None
super(Client,self).__init__(socket=_socket)
self.MP_VERSION = np.int32(1)
self.extra = ""
self.havedata = False
self.needsinit = True
self.vir = np.zeros((3,3),np.float64)
self.cellh = np.zeros((3,3),np.float64)
self.cellih = np.zeros((3,3),np.float64)
self.nat = np.int32()

# needs to be set from the client
# For now manually set the verbosity
self.verb = verbosity.High


def _getforce(self):
"""Dummy _getforce routine.

This function must be implemented by subclassing or providing a callback function.
This function is assumed to calculate the following:
- self._force: The force of the current positions at self._positions.
- self._potential: The potential of the current positions at self._positions.
"""
if hasattr(self, "callback"):
self._force, self._potential = self.callback(self._positions)
else:
raise NotImplementedError("_getforce must be implemented by providing a self.callback function or overwritten.")


def run(self):
"""Serve forces until asked to finish.

Serve force and potential, that are calculated in the user provided
routine _getforce.
"""
while 1:
msg = self.recv_msg()
if msg == "":
if self.verb > verbosity.Quiet:
print " @CLIENT: Shutting down."
break
elif msg == Message("status"):
if self.needsinit:
self.send_msg("needinit")
elif self.havedata:
self.send_msg("havedata")
else:
self.send_msg("ready")
elif msg == Message("init"):
self.bead = np.int32()
self.bead = self.recvall(self.bead)
str_pars = np.int32()
str_pars = self.recvall(str_pars)
self.pars = " "*str_pars
if str_pars > 0:
self.pars = self.recv(len(self.pars))
self.needsinit = False
elif msg == Message("posdata"):
self.cellh = self.recvall(self.cellh)
self.cellih = self.recvall(self.cellih)
self.nat = self.recvall(self.nat)
self._positions = np.zeros((self.nat,3),np.float64)
self._positions = self.recvall(self._positions)
self._getforce()
self.havedata = True
elif msg == Message("getforce"):
self.sendall(Message("forceready"))
self.sendall(self._potential, 8)
self.sendall(self.nat, 4)
self.sendall(self._force, len(self._force)*8)
self.havedata=False
else:
print >>sys.stderr, " @CLIENT: Couldn't understand command:", msg
break