diff --git a/docs/data/torch.html b/docs/data/torch.html
index 9d6eada..c39c947 100644
--- a/docs/data/torch.html
+++ b/docs/data/torch.html
@@ -938,12 +938,15 @@
Module proteinflow.data.torch
Parameters
----------
- cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}
- The CDR to be iterated over. Set to `None` to go back to iterating over all chains.
+ cdr : list | str | None
+ The CDR to be iterated over (choose from H1, H2, H3, L1, L2, L3).
+ Set to `None` to go back to iterating over all chains.
"""
if not self.sabdab:
cdr = None
+ if isinstance(cdr, str):
+ cdr = [cdr]
if cdr == self.cdr:
return
self.cdr = cdr
@@ -954,12 +957,12 @@ Module proteinflow.data.torch
print(f"Setting CDR to {cdr}...")
for i, data in tqdm(enumerate(self.data)):
if self.clusters is not None:
- if data.split("__")[1] == cdr:
+ if data.split("__")[1] in cdr:
self.indices.append(i)
else:
add = False
for chain in self.files[data]:
- if chain.split("__")[1] == cdr:
+ if chain.split("__")[1] in cdr:
add = True
break
if add:
@@ -1091,7 +1094,7 @@ Module proteinflow.data.torch
id = self.data[idx] # data is already filtered by length
chain_id = random.choice(list(self.files[id].keys()))
if self.cdr is not None:
- while self.cdr != chain_id.split("__")[1]:
+ while chain_id.split("__")[1] not in self.cdr:
chain_id = random.choice(list(self.files[id].keys()))
else:
cluster = self.data[idx]
@@ -1937,12 +1940,15 @@ Parameters
Parameters
----------
- cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}
- The CDR to be iterated over. Set to `None` to go back to iterating over all chains.
+ cdr : list | str | None
+ The CDR to be iterated over (choose from H1, H2, H3, L1, L2, L3).
+ Set to `None` to go back to iterating over all chains.
"""
if not self.sabdab:
cdr = None
+ if isinstance(cdr, str):
+ cdr = [cdr]
if cdr == self.cdr:
return
self.cdr = cdr
@@ -1953,12 +1959,12 @@ Parameters
print(f"Setting CDR to {cdr}...")
for i, data in tqdm(enumerate(self.data)):
if self.clusters is not None:
- if data.split("__")[1] == cdr:
+ if data.split("__")[1] in cdr:
self.indices.append(i)
else:
add = False
for chain in self.files[data]:
- if chain.split("__")[1] == cdr:
+ if chain.split("__")[1] in cdr:
add = True
break
if add:
@@ -2090,7 +2096,7 @@ Parameters
id = self.data[idx] # data is already filtered by length
chain_id = random.choice(list(self.files[id].keys()))
if self.cdr is not None:
- while self.cdr != chain_id.split("__")[1]:
+ while chain_id.split("__")[1] not in self.cdr:
chain_id = random.choice(list(self.files[id].keys()))
else:
cluster = self.data[idx]
@@ -2210,8 +2216,9 @@ Methods
Set the CDR to be iterated over (only for SAbDab datasets).
Parameters
-cdr
: {"H1", "H2", "H3", "L1", "L2", "L3"}
-- The CDR to be iterated over. Set to
None
to go back to iterating over all chains.
+cdr
: list | str | None
+- The CDR to be iterated over (choose from H1, H2, H3, L1, L2, L3).
+Set to
None
to go back to iterating over all chains.
@@ -2222,12 +2229,15 @@ Parameters
Parameters
----------
- cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}
- The CDR to be iterated over. Set to `None` to go back to iterating over all chains.
+ cdr : list | str | None
+ The CDR to be iterated over (choose from H1, H2, H3, L1, L2, L3).
+ Set to `None` to go back to iterating over all chains.
"""
if not self.sabdab:
cdr = None
+ if isinstance(cdr, str):
+ cdr = [cdr]
if cdr == self.cdr:
return
self.cdr = cdr
@@ -2238,12 +2248,12 @@ Parameters
print(f"Setting CDR to {cdr}...")
for i, data in tqdm(enumerate(self.data)):
if self.clusters is not None:
- if data.split("__")[1] == cdr:
+ if data.split("__")[1] in cdr:
self.indices.append(i)
else:
add = False
for chain in self.files[data]:
- if chain.split("__")[1] == cdr:
+ if chain.split("__")[1] in cdr:
add = True
break
if add: