from torchview import draw_graph

# Build (and save to disk) a torchview visualization of `model`'s computation graph.
#
# NOTE(review): `model` is not defined in this snippet. It must be a torch
# nn.Module created earlier in the file/notebook — confirm.
# input_size=(1, 1, 28, 28) is presumably a batch of one single-channel
# 28x28 image, i.e. (batch, channels, height, width). Confirm it matches what
# the model's first layer expects, otherwise the trace forward pass will fail.
# save_graph=True asks torchview to write the rendered graph to a file.
# expand_nested=True draws nested submodules expanded instead of collapsed.
model_graph = draw_graph(
    model,
    input_size=(1, 1, 28, 28),
    save_graph=True,
    expand_nested=True,
)