As an exercise to improve my coding habits, I am writing a package that implements some neural network architectures using PyTorch and lets the user train, test and exploit those architectures.
I have an exploit.py file that handles the command-line arguments provided by the user. Among these arguments, the user can provide a configuration file, which I parse using configparser. In this configuration, the user can specify the network architecture they want to use.
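To make this concrete, here is a minimal sketch of how the architecture name could be read with configparser. The [network] section and the architecture key are hypothetical names chosen for illustration, and the configuration is read from a string only to keep the example self-contained.

```python
import configparser

# Hypothetical config contents; in exploit.py this would instead come from
# the file given on the command line, e.g. parser.read(path).
CONFIG_TEXT = """
[network]
architecture = architecture1
"""

parser = configparser.ConfigParser()
parser.read_string(CONFIG_TEXT)

# fallback=None returns None for a missing key instead of raising NoOptionError
arch = parser.get("network", "architecture", fallback=None)
print(arch)
```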
What I would like to do is create an instance of the right architecture depending on the user input. I have a module named architecture.py in which I define my architectures. Below is a pseudo-code example of this architecture.py file:
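from torch import nn


class myArchitecture(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        pass  # placeholder: the actual layers would be applied here


class myNextArchitecture(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        pass  # placeholder: the actual layers would be applied here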
So the user could put architecture1 or architecture2 in the configuration file (let's imagine these are user-friendly names), and I would need to handle loading the corresponding instance. Obviously, I could do something like this in the exploit.py file, assuming I have already parsed the architecture name into an arch variable:
import architecture

if arch == 'architecture1':
    net = architecture.myArchitecture()
elif arch == 'architecture2':
    net = architecture.myNextArchitecture()
else:
    print("Architecture not handled")
    # Which kind of error should I raise here? Should I make a custom one?
But I am wondering:
- Is this a good way to handle such cases?
- How should I raise the exception if the user inputs a wrong architecture name? Should I use a built-in exception, and if so, which one? Or should I create my own?
Or should I somehow write a more explicit method like architecture.load_model(arch)? I imagine this method would be located in architecture.py and would do something like:
def load_model(name):
    if name == 'architecture1':
        return myArchitecture()  # I assume this would work since located in `architecture.py`
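One possible shape for such a load_model function, shown only as a sketch, is a dictionary that maps the user-facing names to classes, so the if/elif chain disappears and unknown names are rejected in one place. The ARCHITECTURES dictionary is a hypothetical name, the two classes are plain stand-ins for the nn.Module subclasses so the example runs without PyTorch, and ValueError is just one reasonable built-in choice, not the only one.

```python
class myArchitecture:
    """Stand-in for the nn.Module subclass of the same name."""


class myNextArchitecture:
    """Stand-in for the nn.Module subclass of the same name."""


# Hypothetical registry: user-facing config names mapped to classes.
ARCHITECTURES = {
    "architecture1": myArchitecture,
    "architecture2": myNextArchitecture,
}


def load_model(name):
    """Return a new instance of the architecture registered under `name`."""
    try:
        cls = ARCHITECTURES[name]
    except KeyError:
        known = ", ".join(sorted(ARCHITECTURES))
        # ValueError: the argument has the right type but an unsupported value.
        raise ValueError(
            f"Unknown architecture {name!r}; expected one of: {known}"
        ) from None
    return cls()
```

With this layout, exploit.py would only need to call architecture.load_model(arch) and could catch ValueError to print a friendly message before exiting. Adding a new architecture then means adding one entry to the dictionary.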