Skip to content

Commit

Permalink
added overfit test, updated documentation, fixed yapf and flake8 erro…
Browse files Browse the repository at this point in the history
…rs (deepchem#4184)

* added overfit test, updated documentation, fixed yapf and flake8 errors

* minor fix
  • Loading branch information
KitVB authored Nov 25, 2024
1 parent b2a130e commit 4a29b98
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 19 deletions.
2 changes: 1 addition & 1 deletion deepchem/data/tests/test_deepvariant_pileup_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_pileup(self):

# Assert the number of reads
self.assertEqual(len(image_dataset), 15)
self.assertEqual(image_dataset.X[0].shape, (100, 221, 6))
self.assertEqual(image_dataset.X[0].shape, (299, 299, 6))


if __name__ == "__main__":
Expand Down
8 changes: 5 additions & 3 deletions deepchem/feat/deepvariant_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
try:
import dgl
import torch
import pysam
except ImportError:
pass

Expand Down Expand Up @@ -606,8 +607,8 @@ def seq_to_int(seq):

def fast_pass_aligner(self, assembled_region: Dict[str, Any]) -> List[Any]:
"""
Align reads to the haplotype of the assembled region using Striped Smith
Waterman algorithm.
Align reads to the haplotype of the assembled region using Striped
Smith Waterman algorithm.
Parameters
----------
Expand Down Expand Up @@ -807,7 +808,8 @@ def _featurize(self, datapoint):
decoded_sequences.append(decoded_seq)

# Map the sequences to chrom names
chrom_names = ["chr1", "chr2"]
with pysam.FastaFile(reference_file_path) as fasta_file:
chrom_names = fasta_file.references

reference_seq_dict = {
chrom_names[i]: seq for i, seq in enumerate(decoded_sequences)
Expand Down
23 changes: 16 additions & 7 deletions deepchem/feat/deepvariant_pileup_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from deepchem.data import ImageDataset
from typing import List

try:
import pysam
except ImportError:
pass


class PileupFeaturizer(Featurizer):
"""
Expand All @@ -23,13 +28,16 @@ class PileupFeaturizer(Featurizer):
>>> from deepchem.feat import RealignerFeaturizer, PileupFeaturizer
>>> bamfile_path = 'deepchem/data/tests/example.bam'
>>> reference_path = 'deepchem/data/tests/sample.fa'
>>> realigner_feat = RealignerFeaturizer()
>>> windows_haplotypes = realigner_feat.featurize((bamfile_path, reference_path))
>>> realigner= RealignerFeaturizer()
>>> windows_haplotypes = realigner.featurize((bamfile_path,reference_path))
>>> pileup_feat = PileupFeaturizer()
>>> features = pileup_feat.featurize((windows_haplotypes, reference_path))
>>> features
<ImageDataset X.shape: (15, 100, 221, 6), y.shape: (15,), w.shape: (15,), ids: [0 1 2 ... 12 13 14], task_names: [0]>
Note
----
This class requires pysam to be installed. Pysam can be used with
Linux or MacOS X. To use Pysam on Windows, use Windows Subsystem for
Linux(WSL).
"""

Expand Down Expand Up @@ -91,7 +99,8 @@ def _featurize(self, datapoint):
decoded_sequences.append(decoded_seq)

# Map the sequences to chrom names
chrom_names = ["chr1", "chr2"]
with pysam.FastaFile(reference_file_path) as fasta_file:
chrom_names = fasta_file.references

reference_seq_dict = {
chrom_names[i]: seq for i, seq in enumerate(decoded_sequences)
Expand Down Expand Up @@ -132,8 +141,8 @@ def get_supports_variant_intensity(read, haplotype):
def get_diff_from_ref_intensity(base, ref_base):
return 1.0 if base != ref_base else 0.25

height = 221
width = 100
height = 299
width = 299
num_channels = 6

images = []
Expand Down
19 changes: 11 additions & 8 deletions deepchem/models/torch_models/inceptionv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@


class InceptionV3(nn.Module):
"""
InceptionV3 model architecture for image classification.
"""

def __init__(self,
num_classes=1000,
aux_logits=True,
num_classes: int = 1000,
aux_logits: bool = True,
transform_input: bool = False,
in_channels=6,
dropout_rate=0.5):
in_channels: int = 6,
dropout_rate: float = 0.5) -> None:
super(InceptionV3, self).__init__()
self.aux_logits = aux_logits

Expand Down Expand Up @@ -85,7 +88,7 @@ def __init__(self,
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)

def forward(self, x):
def forward(self, x: torch.Tensor):
# N x 6 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
Expand Down Expand Up @@ -770,16 +773,16 @@ class InceptionV3Model(TorchModel):
def __init__(self,
in_channels=6,
warmup_steps=10000,
learning_rate=0.001,
learning_rate=0.064,
dropout_rate=0.2,
decay_rate=0.94,
**kwargs):
# Fixed hyperparameters
decay_steps = 2 # epochs per decay
decay_rate = 0.947
rho = 0.9
momentum = 0.9
epsilon = 1.0
# weight_decay = 0.00004
dropout_rate = 0.2

# Initialize the InceptionV3 model architecture
model = InceptionV3(num_classes=3,
Expand Down
48 changes: 48 additions & 0 deletions deepchem/models/torch_models/tests/test_inceptionv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,51 @@ def test_InceptionAux():

# Verify the output shape matches expected shape
assert output_tensor.shape == expected_output_shape


@pytest.mark.torch
def test_inceptionv3_overfit():
"""
Test the InceptionV3 model's ability to overfit on a small dataset
"""
from sklearn.metrics import accuracy_score
from deepchem.models.torch_models.inceptionv3 import InceptionV3Model

# Generate a small dataset to test overfitting
input_shape = (3, 3, 299, 299)
input_samples = np.random.randn(*input_shape).astype(np.float32)
output_samples = np.array([0, 1,
2]).astype(np.int64) # One sample per class
one_hot_output_samples = one_hot_encode(output_samples, 3)

dataset = dc.data.ImageDataset(input_samples, one_hot_output_samples)

# Initialize model and set a temporary directory for saving
model_dir = tempfile.mkdtemp()
inception_model = InceptionV3Model(n_tasks=3,
in_channels=3,
warmup_steps=0,
learning_rate=0.1,
decay_rate=1,
dropout_rate=0.0,
model_dir=model_dir)

# Train for many epochs to test overfitting capability
inception_model.fit(dataset, nb_epoch=100)

# Check performance on the small dataset
pred = inception_model.predict(dataset)

# Ensure predictions are in the correct shape to match the number
# of classes
assert pred.shape == (3, 3), f"Unexpected prediction shape: {pred.shape}"

# Convert predictions and labels to one-hot format for metric computation
pred_labels = np.argmax(pred, axis=1)
true_labels = output_samples

# Calculate accuracy using direct comparison
accuracy = accuracy_score(true_labels, pred_labels)

# Assert the accuracy is high, indicating overfitting
assert accuracy > 0.9, "Failed to overfit on small dataset"

0 comments on commit 4a29b98

Please sign in to comment.