Skip to content

Commit

Permalink
test case
Browse files Browse the repository at this point in the history
  • Loading branch information
ajing committed Jul 5, 2017
1 parent 1347414 commit 8136100
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
35 changes: 35 additions & 0 deletions split_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import pandas as pds

def split_train_validation(num_val=3000):
"""
Save train image names and validation image names to csv files
"""
train_image_idx = np.sort(np.random.choice(40479, 40479-num_val, replace=False))
all_idx = np.arange(40479)
validation_image_idx = np.zeros(num_val, dtype=np.int32)
val_idx = 0
train_idx = 0
for i in all_idx:
if i not in train_image_idx:
validation_image_idx[val_idx] = i
val_idx += 1
else:
train_idx += 1
# save train
train = []
for name in train_image_idx:
train.append('train_%s' % name)

eval = []
for name in validation_image_idx:
eval.append('train_%s' % name)

df = pds.DataFrame(train)
df.to_csv('dataset/train-%s' % (40479 - num_val), index=False, header=False)

df = pds.DataFrame(eval)
df.to_csv('dataset/validation-%s' % num_val, index=False, header=False)


split_train_validation(num_val=3000)
89 changes: 89 additions & 0 deletions test_script.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from datasets import load_img\n",
"from torchvision.transforms import *"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"img = load_img(\"/home/ubuntu/Kaggle/AmazonForest/data/train-jpg/train_40471.jpg\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"tranform = Compose([Scale(256),ToTensor()])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "'int' object is not iterable",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-b5702241243e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtranform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/miniconda3/envs/gpu_env/lib/python3.6/site-packages/torchvision-0.1.8-py3.6.egg/torchvision/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/gpu_env/lib/python3.6/site-packages/torchvision-0.1.8-py3.6.egg/torchvision/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 130\u001b[0;31m \u001b[0mw\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 131\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mh\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mw\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mw\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: 'int' object is not iterable"
]
}
],
"source": [
"tranform(img)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 8136100

Please sign in to comment.